diff --git a/.golangci.yml b/.golangci.yml
index 5a0ee2f..dd4fdd1 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -70,6 +70,10 @@ linters:
     # guidelines. See https://github.com/mvdan/gofumpt/issues/235.
     - gofumpt
 
+    # Disable gomnd. Even though we generally don't use magic numbers,
+    # there are exceptions where allowing them improves readability.
+    - gomnd
+
     # Disable whitespace linter as it has conflict rules against our
     # contribution guidelines. See
    # https://github.com/bombsimon/wsl/issues/109.
     #
diff --git a/gbn/config.go b/gbn/config.go
index 4d14f7f..6a027b4 100644
--- a/gbn/config.go
+++ b/gbn/config.go
@@ -38,6 +38,10 @@ type config struct {
	// packet.
	sendToStream sendBytesFunc
 
+	// onFIN is a callback that, if set, will be called once a FIN packet
+	// has been received and processed.
+	onFIN func()
+
	// handshakeTimeout is the time after which the server or client
	// will abort and restart the handshake if the expected response is
	// not received from the peer.
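The onFIN field follows the usual optional-hook pattern: the callback is only invoked if a caller set it. A minimal self-contained sketch of that pattern (the conn and handleFIN names here are hypothetical, for illustration only):

```go
package main

import "fmt"

// conn is a hypothetical stand-in for a struct carrying an optional hook,
// mirroring how the gbn config carries onFIN.
type conn struct {
	onFIN func()
}

// handleFIN shows the nil-guarded invocation: if no callback was
// configured, processing a FIN simply skips the hook.
func (c *conn) handleFIN() {
	if c.onFIN != nil {
		c.onFIN()
	}
}

func main() {
	c := &conn{onFIN: func() { fmt.Println("FIN received") }}
	c.handleFIN()
}
```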
diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go
index e1a8e80..50a0cc9 100644
--- a/gbn/gbn_conn.go
+++ b/gbn/gbn_conn.go
@@ -89,13 +89,10 @@ func newGoBackNConn(ctx context.Context, cfg *config,
	prefix := fmt.Sprintf("(%s)", loggerPrefix)
	plog := build.NewPrefixLog(prefix, log)
 
-	return &GoBackNConn{
-		cfg:          cfg,
-		recvDataChan: make(chan *PacketData, cfg.n),
-		sendDataChan: make(chan *PacketData),
-		sendQueue: newQueue(
-			cfg.n+1, defaultHandshakeTimeout, plog,
-		),
+	g := &GoBackNConn{
+		cfg:          cfg,
+		recvDataChan: make(chan *PacketData, cfg.n),
+		sendDataChan: make(chan *PacketData),
		recvTimeout:       DefaultRecvTimeout,
		sendTimeout:       DefaultSendTimeout,
		receivedACKSignal: make(chan struct{}),
@@ -106,6 +103,17 @@
		log:  plog,
		quit: make(chan struct{}),
	}
+
+	g.sendQueue = newQueue(&queueCfg{
+		s:       cfg.n + 1,
+		timeout: cfg.resendTimeout,
+		log:     plog,
+		sendPkt: func(packet *PacketData) error {
+			return g.sendPacket(g.ctx, packet)
+		},
+	})
+
+	return g
 }
 
 // setN sets the current N to use. This _must_ be set before the handshake is
@@ -114,7 +122,14 @@ func (g *GoBackNConn) setN(n uint8) {
	g.cfg.n = n
	g.cfg.s = n + 1
	g.recvDataChan = make(chan *PacketData, n)
-	g.sendQueue = newQueue(n+1, defaultHandshakeTimeout, g.log)
+	g.sendQueue = newQueue(&queueCfg{
+		s:       n + 1,
+		timeout: g.cfg.resendTimeout,
+		log:     g.log,
+		sendPkt: func(packet *PacketData) error {
+			return g.sendPacket(g.ctx, packet)
+		},
+	})
 }
 
 // SetSendTimeout sets the timeout used in the Send function.
@@ -320,6 +335,8 @@ func (g *GoBackNConn) Close() error {
	// initialisation.
	g.cancel()
 
+	g.sendQueue.stop()
+
	g.wg.Wait()
 
	if g.pingTicker != nil {
@@ -359,9 +376,28 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error {
 func (g *GoBackNConn) sendPacketsForever() error {
	// resendQueue re-sends the current contents of the queue.
	resendQueue := func() error {
-		return g.sendQueue.resend(func(packet *PacketData) error {
-			return g.sendPacket(g.ctx, packet)
-		})
+		err := g.sendQueue.resend()
+		if err != nil {
+			return err
+		}
+
+		// After resending the queue, we reset the resend ticker.
+		// This is so that we don't immediately resend the queue
+		// again if the sendQueue.resend call above took a long time
+		// to execute. That can happen if the function was awaiting
+		// the expected ACK for a long time, or timed out while
+		// awaiting the catch up.
+		g.resendTicker.Reset(g.cfg.resendTimeout)
+
+		// Also drain the resend signal channel, as resendTicker.Reset
+		// doesn't drain the channel if the ticker ticked during the
+		// sendQueue.resend() call above.
+		select {
+		case <-g.resendTicker.C:
+		default:
+		}
+
+		return nil
	}
 
	for {
@@ -478,6 +514,8 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
			g.pongTicker.Pause()
		}
 
+		g.resendTicker.Reset(g.cfg.resendTimeout)
+
		switch m := msg.(type) {
		case *PacketData:
			switch m.Seq == g.recvSeq {
@@ -526,9 +564,19 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
 
				// If we recently sent a NACK for the same
				// sequence number then back off.
-				if lastNackSeq == g.recvSeq &&
-					time.Since(lastNackTime) <
-					g.cfg.resendTimeout {
+				// We wait 2 times the resendTimeout before
+				// sending a new NACK, as this case is likely
+				// hit if the sender is currently resending
+				// the queue. The goroutine performing that
+				// resend is then likely busy with it, and
+				// therefore won't react to the NACK we send
+				// here in time.
+				sinceSent := time.Since(lastNackTime)
+				recentlySent := sinceSent <
+					g.cfg.resendTimeout*2
+
+				if lastNackSeq == g.recvSeq && recentlySent {
+					g.log.Tracef("Recently sent NACK")
					continue
				}
@@ -552,8 +600,6 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
		case *PacketACK:
			gotValidACK := g.sendQueue.processACK(m.Seq)
			if gotValidACK {
-				g.resendTicker.Reset(g.cfg.resendTimeout)
-
				// Send a signal to indicate that new
				// ACKs have been received.
				select {
@@ -569,15 +615,12 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
			// sent was dropped, or maybe we sent a duplicate
			// message. The NACK message contains the sequence
			// number that the receiver was expecting.
-			inQueue, bumped := g.sendQueue.processNACK(m.Seq)
-
-			// If the NACK sequence number is not in our queue
-			// then we ignore it. We must have received the ACK
-			// for the sequence number in the meantime.
-			if !inQueue {
-				g.log.Tracef("NACK seq %d is not in the "+
-					"queue. Ignoring", m.Seq)
+			shouldResend, bumped := g.sendQueue.processNACK(m.Seq)
 
+			// If we don't need to resend the queue after
+			// processing the NACK, we can continue without
+			// sending the resend signal.
+			if !shouldResend {
				continue
			}
@@ -606,6 +649,10 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
 
			close(g.remoteClosed)
 
+			if g.cfg.onFIN != nil {
+				g.cfg.onFIN()
+			}
+
			return errTransportClosing
 
		default:
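The Reset-then-drain sequence in resendQueue above is the standard way to restart a time.Ticker after a long-running operation: Reset alone doesn't clear a tick that already fired into the ticker's (one-element) channel buffer, so a non-blocking receive empties it afterwards. A runnable sketch of the idiom in isolation:

```go
package main

import "time"

// resetAndDrain restarts the ticker for the next interval and drops any
// tick that fired while we were busy, so the next tick arrives a full
// interval from now rather than immediately.
func resetAndDrain(t *time.Ticker, d time.Duration) {
	t.Reset(d)

	// Non-blocking receive: clears a pending tick without waiting.
	select {
	case <-t.C:
	default:
	}
}

func main() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	time.Sleep(1500 * time.Millisecond) // Simulate slow work; a tick fires.
	resetAndDrain(ticker, time.Second)  // Next tick ~1s from now, not stale.
}
```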
diff --git a/gbn/options.go b/gbn/options.go
index 2c1e8df..4adcb3a 100644
--- a/gbn/options.go
+++ b/gbn/options.go
@@ -43,3 +43,11 @@ func WithKeepalivePing(ping, pong time.Duration) Option {
		conn.pongTime = pong
	}
 }
+
+// WithOnFIN is used to set the onFIN callback that will be called once a
+// FIN packet has been received and processed.
+func WithOnFIN(fn func()) Option {
+	return func(conn *config) {
+		conn.onFIN = fn
+	}
+}
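For illustration, this is roughly how a caller combines the new option with the existing ones (the import path is an assumption based on this repository's layout, and the callback body is a placeholder; the mailbox client change later in this diff shows the real usage):

```go
package main

import (
	"fmt"
	"time"

	"github.com/lightninglabs/lightning-node-connect/gbn"
)

func main() {
	// Assemble the options as a caller would; WithOnFIN simply stores
	// the callback on the internal config, to be invoked once a FIN
	// packet has been received and processed.
	opts := []gbn.Option{
		gbn.WithTimeout(time.Second),
		gbn.WithOnFIN(func() {
			fmt.Println("peer initiated shutdown")
		}),
	}
	_ = opts // Passed to the gbn client/server constructor in practice.
}
```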
diff --git a/gbn/queue.go b/gbn/queue.go
index 1f2aa44..99954ba 100644
--- a/gbn/queue.go
+++ b/gbn/queue.go
@@ -7,13 +7,7 @@ import (
	"github.com/btcsuite/btclog"
 )
 
-// queue is a fixed size queue with a sliding window that has a base and a top
-// modulo s.
-type queue struct {
-	// content is the current content of the queue. This is always a slice
-	// of length s but can contain nil elements if the queue isn't full.
-	content []*PacketData
-
+type queueCfg struct {
	// s is the maximum sequence number used to label packets. Packets
	// are labelled with incrementing sequence numbers modulo s.
	// s must be strictly larger than the window size, n. This
@@ -23,6 +17,22 @@
	// no way to tell.
	s uint8
 
+	timeout time.Duration
+
+	log btclog.Logger
+
+	sendPkt func(packet *PacketData) error
+}
+
+// queue is a fixed size queue with a sliding window that has a base and a top
+// modulo s.
+type queue struct {
+	cfg *queueCfg
+
+	// content is the current content of the queue. This is always a slice
+	// of length s but can contain nil elements if the queue isn't full.
+	content []*PacketData
+
	// sequenceBase keeps track of the base of the send window and so
	// represents the next ack that we expect from the receiver. The
	// maximum value of sequenceBase is s.
@@ -41,27 +51,32 @@ type queue struct {
	// topMtx is used to guard sequenceTop.
	topMtx sync.RWMutex
 
-	lastResend       time.Time
-	handshakeTimeout time.Duration
+	syncer *syncer
 
-	log btclog.Logger
+	lastResend time.Time
+
+	quit chan struct{}
 }
 
 // newQueue creates a new queue.
-func newQueue(s uint8, handshakeTimeout time.Duration,
-	logger btclog.Logger) *queue {
-
-	log := log
-	if logger != nil {
-		log = logger
+func newQueue(cfg *queueCfg) *queue {
+	if cfg.log == nil {
+		cfg.log = log
	}
 
-	return &queue{
-		content:          make([]*PacketData, s),
-		s:                s,
-		handshakeTimeout: handshakeTimeout,
-		log:              log,
+	q := &queue{
+		cfg:     cfg,
+		content: make([]*PacketData, cfg.s),
+		quit:    make(chan struct{}),
	}
+
+	q.syncer = newSyncer(cfg.s, cfg.log, cfg.timeout, q.quit)
+
+	return q
+}
+
+// stop stops the queue and any sync it may currently be awaiting.
+func (q *queue) stop() {
+	close(q.quit)
 }
 
 // size is used to calculate the current sender queueSize.
@@ -76,7 +91,7 @@ func (q *queue) size() uint8 {
		return q.sequenceTop - q.sequenceBase
	}
 
-	return q.sequenceTop + (q.s - q.sequenceBase)
+	return q.sequenceTop + (q.cfg.s - q.sequenceBase)
 }
 
 // addPacket adds a new packet to the queue.
@@ -86,13 +101,15 @@ func (q *queue) addPacket(packet *PacketData) {
 
	packet.Seq = q.sequenceTop
	q.content[q.sequenceTop] = packet
-	q.sequenceTop = (q.sequenceTop + 1) % q.s
+	q.sequenceTop = (q.sequenceTop + 1) % q.cfg.s
 }
 
-// resend invokes the callback for each packet that needs to be re-sent.
-func (q *queue) resend(cb func(packet *PacketData) error) error {
-	if time.Since(q.lastResend) < q.handshakeTimeout {
-		q.log.Tracef("Resent the queue recently.")
+// resend resends the current contents of the queue. It allows some time for
+// the two parties to be seen as synced; this may fail, in which case the
+// caller is expected to call resend again.
+func (q *queue) resend() error {
+	if time.Since(q.lastResend) < q.cfg.timeout {
+		q.cfg.log.Tracef("Resent the queue recently.")
 
		return nil
	}
@@ -115,33 +132,43 @@ func (q *queue) resend(cb func(packet *PacketData) error) error {
		return nil
	}
 
-	q.log.Tracef("Resending the queue")
+	// Prepare the queue for awaiting the resend catch up.
+	q.syncer.initResendUpTo(top)
+
+	q.cfg.log.Tracef("Resending the queue")
 
	for base != top {
		packet := q.content[base]
 
-		if err := cb(packet); err != nil {
+		if err := q.cfg.sendPkt(packet); err != nil {
			return err
		}
 
-		base = (base + 1) % q.s
-		q.log.Tracef("Resent %d", packet.Seq)
+		base = (base + 1) % q.cfg.s
+
+		q.cfg.log.Tracef("Resent %d", packet.Seq)
	}
 
+	// Then wait until we know that both parties are in sync.
+	q.syncer.waitForSync()
+
	return nil
 }
 
-// processACK processes an incoming ACK of a given sequence number.
+// processACK processes an incoming ACK of a given sequence number. The
+// function returns true if the passed seq is an ACK for a packet we have
+// sent but not yet received an ACK for.
 func (q *queue) processACK(seq uint8) bool {
-	// If our queue is empty, an ACK should not have any effect.
	if q.size() == 0 {
-		q.log.Tracef("Received ack %d, but queue is empty. Ignoring.",
-			seq)
+		q.cfg.log.Tracef("Received ack %d, but queue is empty. "+
+			"Ignoring.", seq)
 
		return false
	}
 
+	q.syncer.processACK(seq)
+
	q.baseMtx.Lock()
	defer q.baseMtx.Unlock()
@@ -150,9 +177,9 @@ func (q *queue) processACK(seq uint8) bool {
	// equal to the one we were expecting. So we increase our base
	// accordingly and send a signal to indicate that the queue size
	// has decreased.
-	q.log.Tracef("Received correct ack %d", seq)
+	q.cfg.log.Tracef("Received correct ack %d", seq)
 
-	q.sequenceBase = (q.sequenceBase + 1) % q.s
+	q.sequenceBase = (q.sequenceBase + 1) % q.cfg.s
 
	// We did receive an ACK.
	return true
@@ -162,7 +189,8 @@ func (q *queue) processACK(seq uint8) bool {
	// This could be a duplicate ACK before or it could be that we just
	// missed the ACK for the current base and this is actually an ACK for
	// another packet in the queue.
-	q.log.Tracef("Received wrong ack %d, expected %d", seq, q.sequenceBase)
+	q.cfg.log.Tracef("Received wrong ack %d, expected %d", seq,
+		q.sequenceBase)
 
	q.topMtx.RLock()
	defer q.topMtx.RUnlock()
@@ -171,9 +199,10 @@ func (q *queue) processACK(seq uint8) bool {
	// just missed a previous ACK. We can bump the base to be equal to this
	// sequence number.
	if containsSequence(q.sequenceBase, q.sequenceTop, seq) {
-		q.log.Tracef("Sequence %d is in the queue. Bump the base.", seq)
+		q.cfg.log.Tracef("Sequence %d is in the queue. Bump the base.",
+			seq)
 
-		q.sequenceBase = (seq + 1) % q.s
+		q.sequenceBase = (seq + 1) % q.cfg.s
 
		// We did receive an ACK.
		return true
@@ -184,6 +213,12 @@ func (q *queue) processACK(seq uint8) bool {
 }
 
 // processNACK processes an incoming NACK of a given sequence number.
+// The function returns two booleans. The first boolean is set to true if we
+// should resend the queue after processing the NACK. The second boolean is
+// set to true if the NACK sequence number is a packet in the queue which
+// isn't the queue base, and we therefore don't need to resend any of the
+// packets before the NACK sequence number. This is equivalent to receiving
+// the ACKs for the packets before the NACK sequence number.
 func (q *queue) processNACK(seq uint8) (bool, bool) {
	q.baseMtx.Lock()
	defer q.baseMtx.Unlock()
@@ -191,19 +226,27 @@ func (q *queue) processNACK(seq uint8) (bool, bool) {
	q.topMtx.RLock()
	defer q.topMtx.RUnlock()
 
-	q.log.Tracef("Received NACK %d", seq)
+	q.cfg.log.Tracef("Received NACK %d", seq)
+
+	q.syncer.processNACK(seq)
 
	// If the NACK is the same as sequenceTop, it probably means that queue
-	// was sent successfully, but we just missed the necessary ACKs. So we
-	// can empty the queue here by bumping the base and we dont need to
+	// was sent successfully, but due to latency we timed out and resent
+	// the queue before we received the ACKs for the sent packets.
+	// Alternatively, we might have just missed the necessary ACKs. So we
+	// can empty the queue here by bumping the base and we don't need to
	// trigger a resend.
	if seq == q.sequenceTop {
		q.sequenceBase = q.sequenceTop
-		return true, false
+
+		return false, true
	}
 
	// Is the NACKed sequence even in our queue?
	if !containsSequence(q.sequenceBase, q.sequenceTop, seq) {
+		q.cfg.log.Tracef("NACK seq %d is not in the queue. Ignoring.",
+			seq)
+
		return false, false
	}
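The q.cfg.s plumbing above leaves the modulo window arithmetic unchanged; size still has to account for the top wrapping around past the base. A standalone sketch of that calculation with a concrete wrapped example:

```go
package main

import "fmt"

// windowSize mirrors the queue.size logic above: the number of un-ACKed
// packets between base and top, with sequence numbers taken modulo s.
func windowSize(base, top, s uint8) uint8 {
	if top >= base {
		return top - base
	}

	// The window wrapped around past the maximum sequence number.
	return top + (s - base)
}

func main() {
	// With s = 5, base = 3 and top = 1 the window has wrapped: it holds
	// packets 3, 4 and 0, so the size is 3.
	fmt.Println(windowSize(3, 1, 5)) // 3
}
```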
Ignoring.", + seq) + return false, false } diff --git a/gbn/queue_test.go b/gbn/queue_test.go index fbcace1..be8e53f 100644 --- a/gbn/queue_test.go +++ b/gbn/queue_test.go @@ -1,13 +1,16 @@ package gbn import ( + "sync" "testing" + "time" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/stretchr/testify/require" ) func TestQueueSize(t *testing.T) { - q := newQueue(4, 0, nil) + q := newQueue(&queueCfg{s: 4}) require.Equal(t, uint8(0), q.size()) @@ -19,3 +22,131 @@ func TestQueueSize(t *testing.T) { q.sequenceTop = 2 require.Equal(t, uint8(3), q.size()) } + +// TestQueueResend tests that the queue resend functionality works as expected. +// It specifically tests that we actually resend packets, and await the expected +// durations for cases when we resend and 1) don't receive the expected +// ACK/NACK, 2) receive the expected ACK, and 3) receive the expected NACK. +func TestQueueResend(t *testing.T) { + t.Parallel() + + resentPackets := make(map[uint8]struct{}) + queueTimeout := time.Second * 1 + + cfg := &queueCfg{ + s: 5, + timeout: queueTimeout, + sendPkt: func(packet *PacketData) error { + resentPackets[packet.Seq] = struct{}{} + + return nil + }, + } + q := newQueue(cfg) + + pkt1 := &PacketData{Seq: 1} + pkt2 := &PacketData{Seq: 2} + pkt3 := &PacketData{Seq: 3} + + q.addPacket(pkt1) + + // First test that we shouldn't resend if the timeout hasn't passed. + q.lastResend = time.Now() + + err := q.resend() + require.NoError(t, err) + + require.Empty(t, resentPackets) + + // Secondly, let's test that we do resend if the timeout has passed, and + // that we then start a sync once we've resent the packet. + q.lastResend = time.Now().Add(-queueTimeout * 2) + + // Let's first test the syncing scenario where we don't receive + // the expected ACK/NACK for the resent packet. This should trigger a + // timeout of the syncing, which should be the + // queueTimeout * awaitingTimeoutMultiplier. + startTime := time.Now() + + var wg sync.WaitGroup + resend(t, q, &wg) + + wg.Wait() + + // Check that the resend took at least the + // queueTimeout * awaitingTimeoutMultiplier for the syncing to + // complete, and that we actually resent the packet. + require.GreaterOrEqual( + t, time.Since(startTime), + queueTimeout*awaitingTimeoutMultiplier, + ) + require.Contains(t, resentPackets, pkt1.Seq) + + // Now let's test the syncing scenario where we do receive the + // expected ACK for the resent packet. This should trigger a + // queue.proceedAfterTime call, which should finish the syncing + // after the queueTimeout. + q.lastResend = time.Now().Add(-queueTimeout * 2) + + q.addPacket(pkt2) + + startTime = time.Now() + + resend(t, q, &wg) + + // Simulate that we receive the expected ACK for the resent packet. + q.processACK(pkt2.Seq) + + wg.Wait() + + // Now check that the resend took at least the queueTimeout for the + // syncing to complete, and that we actually resent the packet. + require.GreaterOrEqual(t, time.Since(startTime), queueTimeout) + require.LessOrEqual(t, time.Since(startTime), queueTimeout*2) + require.Contains(t, resentPackets, pkt2.Seq) + + // Finally, let's test the syncing scenario where we do receive the + // expected NACK for the resent packet. This make the syncing + // complete immediately. + q.lastResend = time.Now().Add(-queueTimeout * 2) + + q.addPacket(pkt3) + + startTime = time.Now() + resend(t, q, &wg) + + // Simulate that we receive the expected NACK for the resent packet. 
diff --git a/gbn/syncer.go b/gbn/syncer.go
new file mode 100644
index 0000000..2289023
--- /dev/null
+++ b/gbn/syncer.go
@@ -0,0 +1,306 @@
+package gbn
+
+import (
+	"sync"
+	"time"
+
+	"github.com/btcsuite/btclog"
+)
+
+const (
+	// awaitingTimeoutMultiplier defines the multiplier we use when
+	// multiplying the resend timeout, resulting in the duration we wait
+	// for the sync to be complete before timing out.
+	// We set this to 3X the resend timeout. The reason we wait exactly
+	// 3X the resend timeout is that we expect the max time the correct
+	// behavior would take to be:
+	// * 1X the resendTimeout for the time it would take for the other
+	// party to respond with an ACK for the last packet in the resend
+	// queue, i.e. the expectedACK.
+	// * 1X the resendTimeout while waiting in proceedAfterTime before
+	// completing the sync.
+	// * 1X extra resendTimeout as a buffer, to ensure that we have
+	// enough time to process the ACKs/NACKs sent by the other party +
+	// some extra margin.
+	awaitingTimeoutMultiplier = 3
+)
+
+type syncState uint8
+
+const (
+	// syncStateIdle is the state representing that the syncer is idle
+	// and has not yet initiated a resend sync.
+	syncStateIdle syncState = iota
+
+	// syncStateResending is the state representing that the syncer has
+	// initiated a resend sync, and is awaiting that the sync is
+	// completed.
+	syncStateResending
+)
+
+// syncer is used to ensure that both the sender and the receiver are in sync
+// before the waitForSync function is completed. This is done by waiting
+// until we receive either the expected ACK or NACK after resending the
+// queue.
+//
+// We need to wait for the expected ACK/NACK after resending the queue to
+// ensure that we don't end up in a situation where we resend the queue over
+// and over again due to latency and delayed NACKs by the other party.
+//
+// Consider the following scenario:
+// 1.
+// Alice sends packets 1, 2, 3 & 4 to Bob.
+// 2.
+// Bob receives packets 1, 2, 3 & 4, and sends back the respective ACKs.
+// 3.
+// Alice receives ACKs for packets 1 & 2, but due to latency the ACKs for
+// packets 3 & 4 are delayed and aren't received until Alice's resend
+// timeout has passed, which leads to Alice resending packets 3 & 4. Alice
+// will after that receive the delayed ACKs for packets 3 & 4, but will
+// consider them the ACKs for the resent packets, and not the original
+// packets which they were actually sent for. If we didn't wait after
+// resending the queue, Alice would then proceed to send more packets
+// (5 & 6).
+// 4.
+// When Bob receives the resent packets 3 & 4, Bob will respond with NACK 5.
+// Due to latency, the packets 5 & 6 that Alice sent in step (3) above will
+// then be received by Bob, and be processed as the correct response to the
+// NACK 5. Bob will after that await packet 7.
+// 5.
+// Alice will receive the NACK 5, and now resend packets 5 & 6. But as Bob
+// is now awaiting packet 7, this send will lead to a NACK 7. But due to
+// latency, if Alice doesn't wait after resending the queue, Alice will
+// proceed to send new packet(s) before receiving the NACK 7.
+// 6.
+// This resend loop would continue indefinitely, so we need to ensure that
+// Alice waits after she has resent the queue, to ensure that she doesn't
+// proceed to send new packets before she is sure that both parties are in
+// sync.
+//
+// To ensure that we are in sync, after we have resent the queue, we will
+// await that we either:
+// 1. Receive a NACK for the sequence number succeeding the last packet in
+// the resent queue, i.e. in step (3) above, that would be NACK 5.
+// OR
+// 2. Receive an ACK for the last packet in the resent queue, i.e. in step
+// (3) above, that would be ACK 4. After we receive the expected ACK, we
+// will then wait for the duration of the resend timeout before continuing.
+// The reason why we wait for the resend timeout before continuing is that
+// the ACKs we are getting after a resend could be delayed ACKs for the
+// original packets we sent, and not ACKs for the resent packets. In step
+// (3) above, the ACKs for packets 3 & 4 that Alice received were delayed
+// ACKs for the original packets. If Alice had immediately continued to send
+// new packets (5 & 6) after receiving the ACK 4, she would have then
+// received the NACK 5 from Bob, which was the actual response to the resent
+// queue. But as Alice had already continued to send packets 5 & 6 when
+// receiving the NACK 5, the resend of the queue in response to that NACK
+// would cause the resend loop to continue indefinitely.
+// OR
+// 3. If neither condition 1 nor 2 above is met within 3X the resend
+// timeout, we will time out and consider the sync to be completed. See the
+// docs for awaitingTimeoutMultiplier for more details on why we wait 3X the
+// resend timeout.
+//
+// When any of the 3 conditions above is met, we will consider both parties
+// to be in sync.
+type syncer struct {
+	s       uint8
+	log     btclog.Logger
+	timeout time.Duration
+
+	state syncState
+
+	// expectedACK defines the sequence number for the last packet in the
+	// resend queue. If we receive an ACK for this sequence number while
+	// waiting to sync, we wait for the duration of the resend timeout,
+	// and then proceed to send new packets, unless we receive the
+	// expectedNACK during the wait time. If that happens, we will
+	// proceed to send new packets as soon as we have processed the NACK.
+	expectedACK uint8
+
+	// expectedNACK is set to the sequence number that follows the last
+	// item in the resend queue, when a sync is initiated. In case we get
+	// a NACK with this sequence number when waiting to sync, we consider
+	// the sync to be completed and we can proceed to send new packets.
+	expectedNACK uint8
+
+	// cancel is used to mark that the sync has been completed.
+	cancel chan struct{}
+
+	quit chan struct{}
+	mu   sync.Mutex
+}
+
+// newSyncer creates a new syncer instance.
+func newSyncer(s uint8, prefixLogger btclog.Logger, timeout time.Duration,
+	quit chan struct{}) *syncer {
+
+	if prefixLogger == nil {
+		prefixLogger = log
+	}
+
+	return &syncer{
+		s:       s,
+		log:     prefixLogger,
+		timeout: timeout,
+		state:   syncStateIdle,
+		cancel:  make(chan struct{}),
+		quit:    quit,
+	}
+}
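To make the expectedACK/expectedNACK bookkeeping concrete: given top, the sequence number of the next packet to be sent after the resent queue, the last resent packet is one below it modulo s. This is the computation initResendUpTo performs further down; a standalone sketch with concrete numbers (s = 8 chosen arbitrarily for illustration):

```go
package main

import "fmt"

// expected mirrors syncer.initResendUpTo below: given the sequence number
// of the next packet to send after the resent queue (top), the last resent
// packet is one below it modulo s, so that's the ACK that completes the
// sync, while a NACK for top itself signals the receiver has caught up.
func expected(top, s uint8) (expectedACK, expectedNACK uint8) {
	return (s + top - 1) % s, top
}

func main() {
	// In the Alice/Bob scenario above, the resent queue ends before
	// packet 5: we then await ACK 4 or NACK 5.
	ack, nack := expected(5, 8)
	fmt.Println(ack, nack) // 4 5
}
```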
+// reset resets the syncer state to idle and marks the sync as completed.
+func (c *syncer) reset() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.resetUnsafe()
+}
+
+// resetUnsafe resets the syncer state to idle and marks the sync as
+// completed.
+//
+// NOTE: when calling this function, the caller must hold the syncer mutex.
+func (c *syncer) resetUnsafe() {
+	c.state = syncStateIdle
+
+	// Cancel any pending sync.
+	select {
+	case c.cancel <- struct{}{}:
+	default:
+	}
+}
+
+// initResendUpTo initializes the syncer to the resending state. After this
+// call, the syncer is ready to wait for the sync to be completed via the
+// waitForSync function.
+// The top argument defines the sequence number of the next packet to be
+// sent after resending the queue.
+func (c *syncer) initResendUpTo(top uint8) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.state = syncStateResending
+
+	// Drain the cancel channel, to reinitialize it for the new sync.
+	select {
+	case <-c.cancel:
+	default:
+	}
+
+	c.expectedACK = (c.s + top - 1) % c.s
+	c.expectedNACK = top
+
+	c.log.Tracef("Set expectedACK to %d & expectedNACK to %d",
+		c.expectedACK, c.expectedNACK)
+}
+
+// getState returns the current state of the syncer.
+func (c *syncer) getState() syncState {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	return c.state
+}
+
+// waitForSync waits for the sync to be completed. The sync is completed
+// either when we receive the expectedNACK, when we receive the expectedACK
+// and the resend timeout has subsequently passed, or when we time out while
+// waiting.
+func (c *syncer) waitForSync() {
+	c.log.Tracef("Awaiting sync after resending the queue")
+
+	select {
+	case <-c.quit:
+		return
+
+	case <-c.cancel:
+		c.log.Tracef("sync canceled or reset")
+
+	case <-time.After(c.timeout * awaitingTimeoutMultiplier):
+		c.log.Tracef("Timed out while waiting for sync")
+	}
+
+	c.reset()
+}
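The cancel channel above is unbuffered, so the non-blocking send in resetUnsafe only lands if another goroutine (waitForSync or proceedAfterTime) is currently blocked receiving on it; otherwise the signal is dropped rather than deadlocking. A runnable sketch of that signalling pattern in isolation:

```go
package main

import (
	"fmt"
	"time"
)

// signal performs the non-blocking send used by resetUnsafe above: on an
// unbuffered channel the send only succeeds if a receiver is currently
// blocked on it; otherwise it's a no-op instead of blocking forever.
func signal(cancel chan struct{}) {
	select {
	case cancel <- struct{}{}:
	default:
	}
}

func main() {
	cancel := make(chan struct{})

	go func() {
		<-cancel
		fmt.Println("waiter released")
	}()

	time.Sleep(10 * time.Millisecond) // Let the waiter block first.
	signal(cancel)                    // Delivered: a receiver was waiting.
	signal(cancel)                    // No receiver now: silently dropped.
	time.Sleep(10 * time.Millisecond)
}
```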
+// processACK marks the sync as completed if the passed sequence number
+// matches the expectedACK, after the resend timeout has passed.
+// If we are not resending or waiting after a resend, this is a no-op.
+func (c *syncer) processACK(seq uint8) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If we are not resending or waiting after a resend, just swallow
+	// the ACK.
+	if c.state != syncStateResending {
+		return
+	}
+
+	// Else, if we are waiting but this is not the ACK we are waiting
+	// for, just swallow it.
+	if seq != c.expectedACK {
+		return
+	}
+
+	c.log.Tracef("Got expected ACK")
+
+	// We start the proceedAfterTime function in a goroutine, as we
+	// don't want to block the processing of other NACKs/ACKs while
+	// we're waiting for the resend timeout to expire.
+	go c.proceedAfterTime()
+}
+
+// processNACK marks the sync as completed if the passed sequence number
+// matches the expectedNACK.
+// If we are not resending or waiting after a resend, this is a no-op.
+func (c *syncer) processNACK(seq uint8) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If we are not resending or waiting after a resend, just swallow
+	// the NACK.
+	if c.state != syncStateResending {
+		return
+	}
+
+	// Else, if we are waiting but this is not the NACK we are waiting
+	// for, just swallow it.
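proceedAfterTime's select (see the end of this file) is the usual "wait out a delay unless signalled" shape: the timeout path completes the sync, while a cancel or quit signal cuts the wait short. A simplified, runnable sketch of that shape (names are illustrative, not the gbn API):

```go
package main

import (
	"fmt"
	"time"
)

// proceedAfter mirrors the shape of proceedAfterTime: wait out a timeout,
// but let a cancel or quit signal cut the wait short. Returns true only if
// the full timeout elapsed.
func proceedAfter(d time.Duration, cancel, quit chan struct{}) bool {
	select {
	case <-quit:
		return false
	case <-cancel:
		return false
	case <-time.After(d):
		return true
	}
}

func main() {
	cancel := make(chan struct{})
	quit := make(chan struct{})

	go func() {
		time.Sleep(50 * time.Millisecond)
		close(cancel) // Simulate the expected NACK arriving early.
	}()

	fmt.Println(proceedAfter(time.Second, cancel, quit)) // false: cut short
}
```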
+	if seq != c.expectedNACK {
+		return
+	}
+
+	c.log.Tracef("Got expected NACK")
+
+	c.resetUnsafe()
+}
+
+// proceedAfterTime will wait for the resendTimeout and then complete the
+// sync, if we haven't already completed it by receiving the expectedNACK.
+func (c *syncer) proceedAfterTime() {
+	// We wait for the duration of the resendTimeout before completing
+	// the sync, as that's the time we'd expect it to take for the other
+	// party to respond with a NACK, if resending the last packet in the
+	// queue would lead to a NACK. If we receive the expectedNACK before
+	// the timeout, the cancel channel will be sent over, and we can stop
+	// the execution early.
+	select {
+	case <-c.quit:
+		return
+
+	case <-c.cancel:
+		c.log.Tracef("sync succeeded or was reset")
+
+		// As we can't be sure that waitForSync's cancel listener was
+		// triggered before this one, we send over the cancel channel
+		// again, to make sure that both listeners are triggered.
+		c.reset()
+
+		return
+
+	case <-time.After(c.timeout):
+		c.mu.Lock()
+		defer c.mu.Unlock()
+
+		if c.state != syncStateResending {
+			return
+		}
+
+		c.log.Tracef("Completing sync after expectedACK timeout")
+
+		c.resetUnsafe()
+	}
+}
diff --git a/gbn/syncer_test.go b/gbn/syncer_test.go
new file mode 100644
index 0000000..eaca3ff
--- /dev/null
+++ b/gbn/syncer_test.go
@@ -0,0 +1,110 @@
+package gbn
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/lightningnetwork/lnd/lntest/wait"
+	"github.com/stretchr/testify/require"
+)
+
+// TestSyncer tests that the syncer functionality works as expected. It
+// specifically tests that we wait the expected durations for cases when we
+// initiate resends and 1) don't receive the expected ACK/NACK, 2) receive
+// the expected ACK, and 3) receive the expected NACK.
+func TestSyncer(t *testing.T) {
+	t.Parallel()
+
+	syncTimeout := time.Second * 1
+	expectedNACK := uint8(3)
+
+	syncer := newSyncer(5, nil, syncTimeout, make(chan struct{}))
+
+	// Let's first test the scenario where we don't receive the expected
+	// ACK/NACK after initiating the resend. This should trigger a
+	// timeout of the sync, which should be the:
+	// syncTimeout * awaitingTimeoutMultiplier.
+	startTime := time.Now()
+
+	var wg sync.WaitGroup
+	initResend(t, syncer, &wg, expectedNACK)
+
+	wg.Wait()
+
+	// Check that the syncing took at least the
+	// syncTimeout * awaitingTimeoutMultiplier to complete.
+	require.GreaterOrEqual(
+		t, time.Since(startTime),
+		syncTimeout*awaitingTimeoutMultiplier,
+	)
+
+	// Now let's test the scenario where we do receive the expected ACK
+	// when awaiting the sync. This should trigger the
+	// syncer.proceedAfterTime call, which should complete the sync after
+	// the syncTimeout.
+	startTime = time.Now()
+
+	initResend(t, syncer, &wg, expectedNACK)
+
+	// Simulate that we receive the expected ACK.
+	simulateACKProcessing(t, syncer, 1, expectedNACK-1)
+
+	wg.Wait()
+
+	// Now check that the sync took at least the syncTimeout to complete.
+	require.GreaterOrEqual(t, time.Since(startTime), syncTimeout)
+	require.LessOrEqual(t, time.Since(startTime), syncTimeout*2)
+
+	// Finally, let's test the scenario where we do receive the expected
+	// NACK when syncing. This completes the sync immediately.
+	startTime = time.Now()
+
+	initResend(t, syncer, &wg, expectedNACK)
+
+	// Simulate that we receive the expected NACK.
+	syncer.processNACK(expectedNACK)
+
+	wg.Wait()
+
+	// Finally let's check that we didn't exit on the timeout in this
+	// case.
+	require.Less(t, time.Since(startTime), syncTimeout)
+}
+
+// simulateACKProcessing is a helper function that simulates the processing
+// of ACKs for the given range of sequence numbers.
+func simulateACKProcessing(t *testing.T, syncer *syncer, base, last uint8) {
+	t.Helper()
+
+	for i := base; i <= last; i++ {
+		syncer.processACK(i)
+	}
+}
+
+// initResend is a helper function that triggers an initResendUpTo for the
+// given top and then waits for the sync to complete in a goroutine, and
+// notifies the WaitGroup when the sync has completed.
+// The function will block until the initResendUpTo function has executed.
+func initResend(t *testing.T, syncer *syncer, wg *sync.WaitGroup, top uint8) {
+	t.Helper()
+
+	require.Equal(t, syncStateIdle, syncer.getState())
+
+	wg.Add(1)
+
+	// This will trigger a resend catchup, so we launch the resend in a
+	// goroutine.
+	go func() {
+		syncer.initResendUpTo(top)
+		syncer.waitForSync()
+
+		wg.Done()
+	}()
+
+	// We also ensure that the above goroutine has executed the
+	// initResendUpTo function before this function returns.
+	err := wait.Predicate(func() bool {
+		return syncer.getState() == syncStateResending
+	}, time.Second)
+	require.NoError(t, err)
+}
diff --git a/mailbox/client_conn.go b/mailbox/client_conn.go
index 6355a94..445daa3 100644
--- a/mailbox/client_conn.go
+++ b/mailbox/client_conn.go
@@ -157,20 +157,32 @@ func NewClientConn(ctx context.Context, sid [64]byte, serverHost string,
	}
 
	c := &ClientConn{
-		transport: transport,
-		gbnOptions: []gbn.Option{
-			gbn.WithTimeout(gbnTimeout),
-			gbn.WithHandshakeTimeout(gbnHandshakeTimeout),
-			gbn.WithKeepalivePing(
-				gbnClientPingTimeout, gbnPongTimeout,
-			),
-		},
+		transport:   transport,
		status:      ClientStatusNotConnected,
		onNewStatus: onNewStatus,
		quit:        make(chan struct{}),
		cancel:      cancel,
		log:         logger,
	}
+
+	c.gbnOptions = []gbn.Option{
+		gbn.WithTimeout(gbnTimeout),
+		gbn.WithHandshakeTimeout(gbnHandshakeTimeout),
+		gbn.WithKeepalivePing(
+			gbnClientPingTimeout, gbnPongTimeout,
+		),
+		gbn.WithOnFIN(func() {
+			// We force the connection to set a new status after
+			// processing a FIN packet, as on rare occasions the
+			// corresponding server may have had time to close
+			// the connection before we've processed the FIN
+			// packet it sent. In that case, if we didn't force a
+			// new status, the client would never mark the
+			// connection as status ClientStatusSessionNotFound.
+			c.setStatus(ClientStatusSessionNotFound)
+		}),
+	}
+
+	c.connKit = &connKit{
+		ctx:  ctxc,
+		impl: c,