Skip to content

Commit

Permalink
gbn: add client side resend loop protection
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorTigerstrom committed Nov 7, 2023
1 parent 9db0d53 commit 6b38b9e
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 24 deletions.
175 changes: 163 additions & 12 deletions gbn/gbn_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ type GoBackNConn struct {
recvTimeout time.Duration
recvTimeoutMu sync.RWMutex

awaitedACK uint8
awaitedNACK uint8
awaitingCatchUp bool
awaitingCatchUpMu sync.RWMutex

awaitedACKSignal chan struct{}
awaitedNACKSignal chan struct{}

isServer bool

// handshakeTimeout is the time after which the server or client
Expand Down Expand Up @@ -134,6 +142,8 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc,
receivedACKSignal: make(chan struct{}),
resendSignal: make(chan struct{}, 1),
remoteClosed: make(chan struct{}),
awaitedACKSignal: make(chan struct{}, 1),
awaitedNACKSignal: make(chan struct{}, 1),
ctx: ctxc,
cancel: cancel,
quit: make(chan struct{}),
Expand Down Expand Up @@ -378,20 +388,105 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error {
return nil
}

// resendQueue re-sends the current contents of the queue, and awaits that we
// either receive the expected ACK or NACK after resending the queue before
// returning.
func (g *GoBackNConn) resendQueue() error {
g.awaitingCatchUpMu.Lock()

base, top := g.sendQueue.getSequenceTop()

if !g.sendQueue.shouldResend() {
g.awaitingCatchUpMu.Unlock()
log.Tracef("Breaking due to shouldn't resend")
return nil
}

if g.sendQueue.isPayloadPackets(base, top) {
g.awaitedACK = (g.sendQueue.s + top - 1) % g.sendQueue.s
g.awaitedNACK = top
log.Tracef("Set awaitedACK to %d & awaitedNACK to %d",
g.awaitedACK, g.awaitedNACK)

g.awaitingCatchUp = true
} else {
log.Tracef("Won't catch up due to ping packet(s) only")
g.awaitingCatchUp = false
}

err := g.sendQueue.resend(
base, top, func(packet *PacketData) error {
return g.sendPacket(g.ctx, packet)
},
)

// We hold the lock for the duration of the resend to ensure
// that we don't process the delayed ACKs for the packets we
// are resending, during the resend. If that would happen, we
// would start the "proceedAfterTime" timeout while still
// resending packets, which could mean that the NACK that the
// resend will trigger, might be received after the timeout
// has passed. That would cause the resend loop to trigger once
// more.
g.awaitingCatchUpMu.Unlock()

if err != nil {
return err
}

if !g.awaitingCatchUp {
return nil
}

ticker := time.NewTimer(g.resendTimeout * 7)
log.Tracef("Awaiting catchup after resending the queue")

catchupLoop:
for {
select {
case <-g.quit:
return nil
case <-g.awaitedACKSignal:
log.Tracef("Got awaitedACKSignal")
break catchupLoop
case <-g.awaitedNACKSignal:
log.Tracef("Got awaitedNACKSignal")
break catchupLoop
case <-ticker.C:
log.Tracef("Timed out while awaiting catchup")

g.awaitingCatchUpMu.Lock()
g.awaitingCatchUp = false

// If we time out, we need to also reset the
// channels, as they could have been sent over
// when we waiting to take the awaitingCatchUpMu
// lock above after the timeout. If we don't
// reset them, this select case would catch
// the sent signal, next time we resend the
// queue.
g.awaitedACKSignal = make(chan struct{}, 1)
g.awaitedNACKSignal = make(chan struct{}, 1)

g.awaitingCatchUpMu.Unlock()

break catchupLoop
default:
continue
}
}
ticker.Stop()

return nil
}

// sendPacketsForever manages the resending logic. It keeps a cache of up to
// N packets and manages the resending of packets if acks are not received for
// them or if NACKs are received. It reads new data from sendDataChan only
// when there is space in the queue.
//
// This function must be called in a go routine.
func (g *GoBackNConn) sendPacketsForever() error {
// resendQueue re-sends the current contents of the queue.
resendQueue := func() error {
return g.sendQueue.resend(func(packet *PacketData) error {
return g.sendPacket(g.ctx, packet)
})
}

for {
// The queue is not empty. If we receive a resend signal
// or if the resend timeout passes then we resend the
Expand All @@ -403,13 +498,13 @@ func (g *GoBackNConn) sendPacketsForever() error {
return nil

case <-g.resendSignal:
if err := resendQueue(); err != nil {
if err := g.resendQueue(); err != nil {
return err
}
continue

case <-g.resendTicker.C:
if err := resendQueue(); err != nil {
if err := g.resendQueue(); err != nil {
return err
}
continue
Expand Down Expand Up @@ -459,11 +554,11 @@ func (g *GoBackNConn) sendPacketsForever() error {
case <-g.receivedACKSignal:
break
case <-g.resendSignal:
if err := resendQueue(); err != nil {
if err := g.resendQueue(); err != nil {
return err
}
case <-g.resendTicker.C:
if err := resendQueue(); err != nil {
if err := g.resendQueue(); err != nil {
return err
}
}
Expand Down Expand Up @@ -557,7 +652,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
// sequence number then back off.
if lastNackSeq == g.recvSeq &&
time.Since(lastNackTime) < g.resendTimeout {

log.Tracef("%d recently sent NACK for %d", g.seqTest2, m.Seq)
continue
}

Expand All @@ -582,6 +677,13 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
if gotValidACK {
g.resendTicker.Reset(g.resendTimeout)

g.awaitingCatchUpMu.RLock()
if m.Seq == g.awaitedACK && g.awaitingCatchUp {
log.Tracef("Got awaited ACK")
g.proceedAfterTime(g.awaitedACK)
}
g.awaitingCatchUpMu.RUnlock()

// Send a signal to indicate that new
// ACKs have been received.
select {
Expand Down Expand Up @@ -618,6 +720,23 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
}
}

g.awaitingCatchUpMu.Lock()
if g.awaitingCatchUp && m.Seq == g.awaitedNACK {
log.Tracef("Sending awaitedNACKSignal")
g.awaitedNACKSignal <- struct{}{}

g.awaitingCatchUp = false

g.awaitingCatchUpMu.Unlock()

// If we received a NACK we were waiting for, we
// can proceed with sending the next packet.
// Therefore we don't send any resendSignal, as
// this would cause a resend loop.
continue
}
g.awaitingCatchUpMu.Unlock()

log.Tracef("Sending a resend signal (isServer=%v)",
g.isServer)

Expand All @@ -644,3 +763,35 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
}
}
}

func (g *GoBackNConn) proceedAfterTime(awaitedACK uint8) {
cb := func() {
log.Tracef("Executing proceedAfterTime")
g.awaitingCatchUpMu.Lock()

// We want to ensure that this function only sets awaitingCatchUp
// to false if we're still waiting for the awaitingCatchUp, AND
// not being run after the awaitingCatchUp has been disabled and
// then enabled again, which could happen if we received a NACK
// during the time we waited to execute this function, and then
// resent the queue again. Therefore we check that the current
// awaitedACK is the same as the awaitedACK we had before waiting.
_, top := g.sendQueue.getSequenceTop()
currentawaitedACK := (g.sendQueue.s + top - 1) % g.sendQueue.s

if g.awaitingCatchUp && (currentawaitedACK == awaitedACK) {
log.Tracef("Sending awaitedACKSignal")
g.awaitedACKSignal <- struct{}{}

g.awaitingCatchUp = false
} else {
log.Tracef("Ending proceedAfterTime without any "+
"action, awaitingCatchUp=%v, currentawaitedACK=%d, "+
"awaitedACK=%d,",
g.awaitingCatchUp, currentawaitedACK, awaitedACK)
}

g.awaitingCatchUpMu.Unlock()
}
time.AfterFunc(g.resendTimeout*2, cb)
}
51 changes: 39 additions & 12 deletions gbn/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,17 @@ func (q *queue) addPacket(packet *PacketData) {
q.sequenceTop = (q.sequenceTop + 1) % q.s
}

// resend invokes the callback for each packet that needs to be re-sent.
func (q *queue) resend(cb func(packet *PacketData) error) error {
// shouldResend returns true if the queue should be resent.
func (q *queue) shouldResend() bool {
if time.Since(q.lastResend) < q.handshakeTimeout {
log.Tracef("Resent the queue recently.")
log.Tracef("%d: Resent the queue recently.")

return nil
return false
}

if q.size() == 0 {
return nil
log.Tracef("%d: 0 queue size")
return false
}

q.lastResend = time.Now()
Expand All @@ -99,9 +100,12 @@ func (q *queue) resend(cb func(packet *PacketData) error) error {
top := q.sequenceTop
q.topMtx.RUnlock()

if base == top {
return nil
}
return base != top
}

// resend invokes the callback for each packet that needs to be re-sent.
func (q *queue) resend(base, top uint8,
cb func(packet *PacketData) error) error {

log.Tracef("Resending the queue")

Expand All @@ -119,13 +123,25 @@ func (q *queue) resend(cb func(packet *PacketData) error) error {
return nil
}

// isPayloadPackets returns true if the packets for the passed packet indexes
// isn't Ping packets.
func (q *queue) isPayloadPackets(base, top uint8) bool {
for base != top {
packet := q.content[base]

if packet.IsPing {
return false
}
base = (base + 1) % q.s
}
return true
}

// processACK processes an incoming ACK of a given sequence number.
func (q *queue) processACK(seq uint8) bool {

// If our queue is empty, an ACK should not have any effect.
if q.size() == 0 {
log.Tracef("Received ack %d, but queue is empty. Ignoring.",
seq)
log.Tracef("Received ack %d, but queue is empty. Ignoring.", seq)
return false
}

Expand All @@ -149,7 +165,7 @@ func (q *queue) processACK(seq uint8) bool {
// This could be a duplicate ACK before or it could be that we just
// missed the ACK for the current base and this is actually an ACK for
// another packet in the queue.
log.Tracef("Received wrong ack %d, expected %d", seq, q.sequenceBase)
log.Tracef("%Received wrong ack %d, expected %d", seq, q.sequenceBase)

q.topMtx.RLock()
defer q.topMtx.RUnlock()
Expand Down Expand Up @@ -206,6 +222,17 @@ func (q *queue) processNACK(seq uint8) (bool, bool) {
return true, bumped
}

// getSequenceTop returns the current sequence top.
func (q *queue) getSequenceTop() (uint8, uint8) {
q.baseMtx.Lock()
defer q.baseMtx.Unlock()

q.topMtx.RLock()
defer q.topMtx.RUnlock()

return q.sequenceBase, q.sequenceTop
}

// containsSequence is used to determine if a number, seq, is between two other
// numbers, base and top, where all the numbers lie in a finite field (modulo
// space) s.
Expand Down

0 comments on commit 6b38b9e

Please sign in to comment.