diff --git a/gbn/config.go b/gbn/config.go new file mode 100644 index 0000000..4d14f7f --- /dev/null +++ b/gbn/config.go @@ -0,0 +1,62 @@ +package gbn + +import "time" + +// config holds the configuration values for an instance of GoBackNConn. +type config struct { + // n is the window size. The sender can send a maximum of n packets + // before requiring an ack from the receiver for the first packet in + // the window. The value of n is chosen by the client during the + // GoBN handshake. + n uint8 + + // s is the maximum sequence number used to label packets. Packets + // are labelled with incrementing sequence numbers modulo s. + // s must be strictly larger than the window size, n. This + // is so that the receiver can tell if the sender is resending the + // previous window (maybe the sender did not receive the acks) or if + // they are sending the next window. If s <= n then there would be + // no way to tell. + s uint8 + + // maxChunkSize is the maximum payload size in bytes allowed per + // message. If the payload to be sent is larger than maxChunkSize then + // the payload will be split between multiple packets. + // If maxChunkSize is zero then it is disabled and data won't be split + // between packets. + maxChunkSize int + + // resendTimeout is the duration that will be waited before resending + // the packets in the current queue. + resendTimeout time.Duration + + // recvFromStream is the function that will be used to acquire the next + // available packet. + recvFromStream recvBytesFunc + + // sendToStream is the function that will be used to send over our next + // packet. + sendToStream sendBytesFunc + + // handshakeTimeout is the time after which the server or client + // will abort and restart the handshake if the expected response is + // not received from the peer. + handshakeTimeout time.Duration + + pingTime time.Duration + pongTime time.Duration +} + +// newConfig constructs a new config struct. +func newConfig(sendFunc sendBytesFunc, recvFunc recvBytesFunc, + n uint8) *config { + + return &config{ + n: n, + s: n + 1, + recvFromStream: recvFunc, + sendToStream: sendFunc, + resendTimeout: defaultResendTimeout, + handshakeTimeout: defaultHandshakeTimeout, + } +} diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index db750e9..cdfa86f 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -21,13 +21,15 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, math.MaxUint8) } - conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n) + cfg := newConfig(sendFunc, receiveFunc, n) // Apply functional options for _, o := range opts { - o(conn) + o(cfg) } + conn := newGoBackNConn(ctx, cfg, "client") + if err := conn.clientHandshake(); err != nil { if err := conn.Close(); err != nil { log.Errorf("error closing gbn ClientConn: %v", err) @@ -76,7 +78,7 @@ func (g *GoBackNConn) clientHandshake() error { case <-recvNext: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { errChan <- err return @@ -101,7 +103,7 @@ func (g *GoBackNConn) clientHandshake() error { handshake: for { // start Handshake - msg := &PacketSYN{N: g.n} + msg := &PacketSYN{N: g.cfg.n} msgBytes, err := msg.Serialize() if err != nil { return err @@ -109,7 +111,7 @@ handshake: // Send SYN g.log.Debugf("Sending SYN") - if err := g.sendToStream(g.ctx, msgBytes); err != nil { + if err := g.cfg.sendToStream(g.ctx, msgBytes); err != nil { return err } @@ -128,7 +130,7 @@ handshake: var b []byte select { - case <-time.After(g.handshakeTimeout): + case <-time.After(g.cfg.handshakeTimeout): g.log.Debugf("SYN resendTimeout. Resending " + "SYN.") @@ -165,7 +167,7 @@ handshake: g.log.Debugf("Got SYN") - if respSYN.N != g.n { + if respSYN.N != g.cfg.n { return io.EOF } @@ -176,7 +178,7 @@ handshake: return err } - if err := g.sendToStream(g.ctx, synack); err != nil { + if err := g.cfg.sendToStream(g.ctx, synack); err != nil { return err } diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index 5d2a2e6..131f2a3 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -33,27 +33,7 @@ type sendBytesFunc func(ctx context.Context, b []byte) error type recvBytesFunc func(ctx context.Context) ([]byte, error) type GoBackNConn struct { - // n is the window size. The sender can send a maximum of n packets - // before requiring an ack from the receiver for the first packet in - // the window. The value of n is chosen by the client during the - // GoBN handshake. - n uint8 - - // s is the maximum sequence number used to label packets. Packets - // are labelled with incrementing sequence numbers modulo s. - // s must be strictly larger than the window size, n. This - // is so that the receiver can tell if the sender is resending the - // previous window (maybe the sender did not receive the acks) or if - // they are sending the next window. If s <= n then there would be - // no way to tell. - s uint8 - - // maxChunkSize is the maximum payload size in bytes allowed per - // message. If the payload to be sent is larger than maxChunkSize then - // the payload will be split between multiple packets. - // If maxChunkSize is zero then it is disabled and data won't be split - // between packets. - maxChunkSize int + cfg *config sendQueue *queue @@ -61,30 +41,17 @@ type GoBackNConn struct { // sequence that we have received. recvSeq uint8 - // resendTimeout is the duration that will be waited before resending - // the packets in the current queue. - resendTimeout time.Duration - resendTicker *time.Ticker - - recvFromStream recvBytesFunc - sendToStream sendBytesFunc + resendTicker *time.Ticker recvDataChan chan *PacketData sendDataChan chan *PacketData - sendTimeout time.Duration - sendTimeoutMu sync.RWMutex - - recvTimeout time.Duration - recvTimeoutMu sync.RWMutex + sendTimeout time.Duration + recvTimeout time.Duration + timeoutsMu sync.RWMutex log btclog.Logger - // handshakeTimeout is the time after which the server or client - // will abort and restart the handshake if the expected response is - // not received from the peer. - handshakeTimeout time.Duration - // receivedACKSignal channel is used to signal that the queue size has // been decreased. receivedACKSignal chan struct{} @@ -94,11 +61,8 @@ type GoBackNConn struct { // that this channel should only be listened on in one place. resendSignal chan struct{} - pingTime time.Duration - pongTime time.Duration pingTicker *IntervalAwareForceTicker pongTicker *IntervalAwareForceTicker - pongWait chan struct{} ctx context.Context //nolint:containedctx cancel func() @@ -116,29 +80,22 @@ type GoBackNConn struct { // newGoBackNConn creates a GoBackNConn instance with all the members which // are common between client and server initialised. -func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, - recvFunc recvBytesFunc, isServer bool, n uint8) *GoBackNConn { +func newGoBackNConn(ctx context.Context, cfg *config, + loggerPrefix string) *GoBackNConn { ctxc, cancel := context.WithCancel(ctx) // Construct a new prefixed logger. - identifier := "client" - if isServer { - identifier = "server" - } - prefix := fmt.Sprintf("(%s)", identifier) + prefix := fmt.Sprintf("(%s)", loggerPrefix) plog := build.NewPrefixLog(prefix, log) return &GoBackNConn{ - n: n, - s: n + 1, - resendTimeout: defaultResendTimeout, - recvFromStream: recvFunc, - sendToStream: sendFunc, - recvDataChan: make(chan *PacketData, n), - sendDataChan: make(chan *PacketData), - sendQueue: newQueue(n+1, defaultHandshakeTimeout, plog), - handshakeTimeout: defaultHandshakeTimeout, + cfg: cfg, + recvDataChan: make(chan *PacketData, cfg.n), + sendDataChan: make(chan *PacketData), + sendQueue: newQueue( + cfg.n+1, defaultHandshakeTimeout, plog, + ), recvTimeout: DefaultRecvTimeout, sendTimeout: DefaultSendTimeout, receivedACKSignal: make(chan struct{}), @@ -154,24 +111,24 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, // setN sets the current N to use. This _must_ be set before the handshake is // completed. func (g *GoBackNConn) setN(n uint8) { - g.n = n - g.s = n + 1 + g.cfg.n = n + g.cfg.s = n + 1 g.recvDataChan = make(chan *PacketData, n) g.sendQueue = newQueue(n+1, defaultHandshakeTimeout, g.log) } // SetSendTimeout sets the timeout used in the Send function. func (g *GoBackNConn) SetSendTimeout(timeout time.Duration) { - g.sendTimeoutMu.Lock() - defer g.sendTimeoutMu.Unlock() + g.timeoutsMu.Lock() + defer g.timeoutsMu.Unlock() g.sendTimeout = timeout } // SetRecvTimeout sets the timeout used in the Recv function. func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) { - g.recvTimeoutMu.Lock() - defer g.recvTimeoutMu.Unlock() + g.timeoutsMu.Lock() + defer g.timeoutsMu.Unlock() g.recvTimeout = timeout } @@ -185,9 +142,9 @@ func (g *GoBackNConn) Send(data []byte) error { default: } - g.sendTimeoutMu.RLock() + g.timeoutsMu.RLock() ticker := time.NewTimer(g.sendTimeout) - g.sendTimeoutMu.RUnlock() + g.timeoutsMu.RUnlock() defer ticker.Stop() sendPacket := func(packet *PacketData) error { @@ -201,7 +158,7 @@ func (g *GoBackNConn) Send(data []byte) error { } } - if g.maxChunkSize == 0 { + if g.cfg.maxChunkSize == 0 { // Splitting is disabled. return sendPacket(&PacketData{ Payload: data, @@ -210,18 +167,21 @@ func (g *GoBackNConn) Send(data []byte) error { } // Splitting is enabled. Split into packets no larger than maxChunkSize. - sentBytes := 0 + var ( + sentBytes = 0 + maxChunk = g.cfg.maxChunkSize + ) for sentBytes < len(data) { packet := &PacketData{} remainingBytes := len(data) - sentBytes - if remainingBytes <= g.maxChunkSize { + if remainingBytes <= maxChunk { packet.Payload = data[sentBytes:] sentBytes += remainingBytes packet.FinalChunk = true } else { - packet.Payload = data[sentBytes : sentBytes+g.maxChunkSize] - sentBytes += g.maxChunkSize + packet.Payload = data[sentBytes : sentBytes+maxChunk] + sentBytes += maxChunk } if err := sendPacket(packet); err != nil { @@ -246,9 +206,9 @@ func (g *GoBackNConn) Recv() ([]byte, error) { msg *PacketData ) - g.recvTimeoutMu.RLock() + g.timeoutsMu.RLock() ticker := time.NewTimer(g.recvTimeout) - g.recvTimeoutMu.RUnlock() + g.timeoutsMu.RUnlock() defer ticker.Stop() for { @@ -276,21 +236,21 @@ func (g *GoBackNConn) start() { g.log.Debugf("Starting") pingTime := time.Duration(math.MaxInt64) - if g.pingTime != 0 { - pingTime = g.pingTime + if g.cfg.pingTime != 0 { + pingTime = g.cfg.pingTime } g.pingTicker = NewIntervalAwareForceTicker(pingTime) g.pingTicker.Resume() pongTime := time.Duration(math.MaxInt64) - if g.pongTime != 0 { - pongTime = g.pongTime + if g.cfg.pongTime != 0 { + pongTime = g.cfg.pongTime } g.pongTicker = NewIntervalAwareForceTicker(pongTime) - g.resendTicker = time.NewTicker(g.resendTimeout) + g.resendTicker = time.NewTicker(g.cfg.resendTimeout) g.wg.Add(1) go func() { @@ -382,7 +342,7 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { return fmt.Errorf("serialize error: %s", err) } - err = g.sendToStream(ctx, b) + err = g.cfg.sendToStream(ctx, b) if err != nil { return fmt.Errorf("error calling sendToStream: %s", err) } @@ -456,7 +416,7 @@ func (g *GoBackNConn) sendPacketsForever() error { for { // If the queue size is still less than N, we can // continue to add more packets to the queue. - if g.sendQueue.size() < g.n { + if g.sendQueue.size() < g.cfg.n { break } @@ -500,7 +460,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo default: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { return fmt.Errorf("error receiving "+ "from recvFromStream: %s", err) @@ -537,7 +497,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo return err } - g.recvSeq = (g.recvSeq + 1) % g.s + g.recvSeq = (g.recvSeq + 1) % g.cfg.s // If the packet was a ping, then there is no // data to return to the above layer. @@ -567,7 +527,8 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // If we recently sent a NACK for the same // sequence number then back off. if lastNackSeq == g.recvSeq && - time.Since(lastNackTime) < g.resendTimeout { + time.Since(lastNackTime) < + g.cfg.resendTimeout { continue } @@ -591,7 +552,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo case *PacketACK: gotValidACK := g.sendQueue.processACK(m.Seq) if gotValidACK { - g.resendTicker.Reset(g.resendTimeout) + g.resendTicker.Reset(g.cfg.resendTimeout) // Send a signal to indicate that new // ACKs have been received. diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index f510d1c..ab24b2f 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -14,13 +14,15 @@ import ( func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, recvFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { - conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN) + cfg := newConfig(sendFunc, recvFunc, DefaultN) // Apply functional options for _, o := range opts { - o(conn) + o(cfg) } + conn := newGoBackNConn(ctx, cfg, "server") + if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { conn.log.Errorf("Error closing ServerConn: %v", err) @@ -60,7 +62,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo case <-recvNext: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { errChan <- err return @@ -124,7 +126,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo return err } - if err = g.sendToStream(g.ctx, b); err != nil { + if err = g.cfg.sendToStream(g.ctx, b); err != nil { return err } @@ -140,7 +142,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo } select { - case <-time.After(g.handshakeTimeout): + case <-time.After(g.cfg.handshakeTimeout): g.log.Debugf("SYNCACK resendTimeout. Abort and wait " + "for client to re-initiate") continue diff --git a/gbn/options.go b/gbn/options.go index 2c24b74..2c1e8df 100644 --- a/gbn/options.go +++ b/gbn/options.go @@ -2,14 +2,14 @@ package gbn import "time" -type Option func(conn *GoBackNConn) +type Option func(conn *config) // WithMaxSendSize is used to set the maximum payload size in bytes per packet. // If set and a large payload comes through then it will be split up into // multiple packets with payloads no larger than the given maximum size. // A size of zero will disable splitting. func WithMaxSendSize(size int) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.maxChunkSize = size } } @@ -17,7 +17,7 @@ func WithMaxSendSize(size int) Option { // WithTimeout is used to set the resend timeout. This is the time to wait // for ACKs before resending the queue. func WithTimeout(timeout time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.resendTimeout = timeout } } @@ -26,7 +26,7 @@ func WithTimeout(timeout time.Duration) Option { // If the timeout is reached without response from the peer then the handshake // will be aborted and restarted. func WithHandshakeTimeout(timeout time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.handshakeTimeout = timeout } } @@ -38,9 +38,8 @@ func WithHandshakeTimeout(timeout time.Duration) Option { // the connection will be closed if the other side does not respond within // time duration. func WithKeepalivePing(ping, pong time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.pingTime = ping conn.pongTime = pong - conn.pongWait = make(chan struct{}, 1) } }