From 4dc43f8f760685862219c67b035d76bf3a41f896 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Nov 2023 13:03:59 +0200 Subject: [PATCH 1/4] gbn: use a prefixed logger Instead of manually needing to insert "isServer=" everywhere. --- gbn/gbn_client.go | 18 ++++++----- gbn/gbn_conn.go | 79 ++++++++++++++++++++++++++--------------------- gbn/gbn_server.go | 26 +++++++++------- gbn/queue.go | 31 +++++++++++++------ gbn/queue_test.go | 2 +- 5 files changed, 92 insertions(+), 64 deletions(-) diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index f42b41ca..db750e96 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -108,14 +108,15 @@ handshake: } // Send SYN - log.Debugf("Client sending SYN") + g.log.Debugf("Sending SYN") if err := g.sendToStream(g.ctx, msgBytes); err != nil { return err } for { // Wait for SYN - log.Debugf("Client waiting for SYN") + g.log.Debugf("Waiting for SYN") + select { case recvNext <- 1: case <-g.quit: @@ -128,7 +129,9 @@ handshake: var b []byte select { case <-time.After(g.handshakeTimeout): - log.Debugf("SYN resendTimeout. Resending SYN.") + g.log.Debugf("SYN resendTimeout. Resending " + + "SYN.") + continue handshake case <-g.quit: return nil @@ -144,7 +147,8 @@ handshake: return err } - log.Debugf("Client got %T", resp) + g.log.Debugf("Got %T", resp) + switch r := resp.(type) { case *PacketSYN: respSYN = r @@ -159,14 +163,14 @@ handshake: } } - log.Debugf("Client got SYN") + g.log.Debugf("Got SYN") if respSYN.N != g.n { return io.EOF } // Send SYNACK - log.Debugf("Client sending SYNACK") + g.log.Debugf("Sending SYNACK") synack, err := new(PacketSYNACK).Serialize() if err != nil { return err @@ -176,7 +180,7 @@ handshake: return err } - log.Debugf("Client Handshake complete") + g.log.Debugf("Handshake complete") return nil } diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index b6108053..39275ccf 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -8,6 +8,9 @@ import ( "math" "sync" "time" + + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" ) var ( @@ -75,7 +78,7 @@ type GoBackNConn struct { recvTimeout time.Duration recvTimeoutMu sync.RWMutex - isServer bool + log btclog.Logger // handshakeTimeout is the time after which the server or client // will abort and restart the handshake if the expected response is @@ -118,6 +121,14 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, ctxc, cancel := context.WithCancel(ctx) + // Construct a new prefixed logger. + identifier := "client" + if isServer { + identifier = "server" + } + prefix := fmt.Sprintf("(%s)", identifier) + plog := build.NewPrefixLog(prefix, log) + return &GoBackNConn{ n: n, s: n + 1, @@ -126,8 +137,7 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, sendToStream: sendFunc, recvDataChan: make(chan *PacketData, n), sendDataChan: make(chan *PacketData), - isServer: isServer, - sendQueue: newQueue(n+1, defaultHandshakeTimeout), + sendQueue: newQueue(n+1, defaultHandshakeTimeout, plog), handshakeTimeout: defaultHandshakeTimeout, recvTimeout: DefaultRecvTimeout, sendTimeout: DefaultSendTimeout, @@ -136,6 +146,7 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, remoteClosed: make(chan struct{}), ctx: ctxc, cancel: cancel, + log: plog, quit: make(chan struct{}), } } @@ -146,7 +157,7 @@ func (g *GoBackNConn) setN(n uint8) { g.n = n g.s = n + 1 g.recvDataChan = make(chan *PacketData, n) - g.sendQueue = newQueue(n+1, defaultHandshakeTimeout) + g.sendQueue = newQueue(n+1, defaultHandshakeTimeout, g.log) } // SetSendTimeout sets the timeout used in the Send function. @@ -262,7 +273,7 @@ func (g *GoBackNConn) Recv() ([]byte, error) { // start kicks off the various goroutines needed by GoBackNConn. // start should only be called once the handshake has been completed. func (g *GoBackNConn) start() { - log.Debugf("Starting (isServer=%v)", g.isServer) + g.log.Debugf("Starting") pingTime := time.Duration(math.MaxInt64) if g.pingTime != 0 { @@ -286,17 +297,17 @@ func (g *GoBackNConn) start() { defer func() { g.wg.Done() if err := g.Close(); err != nil { - log.Errorf("error closing GoBackNConn: %v", err) + g.log.Errorf("Error closing GoBackNConn: %v", + err) } }() err := g.receivePacketsForever() if err != nil { - log.Debugf("Error in receivePacketsForever "+ - "(isServer=%v): %v", g.isServer, err) + g.log.Debugf("Error in receivePacketsForever: %v", err) } - log.Debugf("receivePacketsForever stopped (isServer=%v)", - g.isServer) + + g.log.Debugf("receivePacketsForever stopped") }() g.wg.Add(1) @@ -304,25 +315,25 @@ func (g *GoBackNConn) start() { defer func() { g.wg.Done() if err := g.Close(); err != nil { - log.Errorf("error closing GoBackNConn: %v", err) + g.log.Errorf("Error closing GoBackNConn: %v", + err) } }() err := g.sendPacketsForever() if err != nil { - log.Debugf("Error in sendPacketsForever "+ - "(isServer=%v): %v", g.isServer, err) + g.log.Debugf("Error in sendPacketsForever: %v", err) } - log.Debugf("sendPacketsForever stopped (isServer=%v)", - g.isServer) + + g.log.Debugf("sendPacketsForever stopped") }() } // Close attempts to cleanly close the connection by sending a FIN message. func (g *GoBackNConn) Close() error { g.closeOnce.Do(func() { - log.Debugf("Closing GoBackNConn, isServer=%v", g.isServer) + g.log.Debugf("Closing GoBackNConn") // We close the quit channel to stop the usual operations of the // server. @@ -333,13 +344,14 @@ func (g *GoBackNConn) Close() error { select { case <-g.remoteClosed: default: - log.Tracef("Try sending FIN, isServer=%v", g.isServer) + g.log.Tracef("Try sending FIN") + ctxc, cancel := context.WithTimeout( g.ctx, finSendTimeout, ) defer cancel() if err := g.sendPacket(ctxc, &PacketFIN{}); err != nil { - log.Errorf("Error sending FIN: %v", err) + g.log.Errorf("Error sending FIN: %v", err) } } @@ -357,7 +369,7 @@ func (g *GoBackNConn) Close() error { g.resendTicker.Stop() } - log.Debugf("GBN is closed, isServer=%v", g.isServer) + g.log.Debugf("GBN is closed") }) return nil @@ -420,8 +432,7 @@ func (g *GoBackNConn) sendPacketsForever() error { g.pongTicker.Reset() g.pongTicker.Resume() - log.Tracef("Sending a PING packet (isServer=%v)", - g.isServer) + g.log.Tracef("Sending a PING packet") packet = &PacketData{ IsPing: true, @@ -437,7 +448,7 @@ func (g *GoBackNConn) sendPacketsForever() error { // send. g.sendQueue.addPacket(packet) - log.Tracef("Sending data %d", packet.Seq) + g.log.Tracef("Sending data %d", packet.Seq) if err := g.sendPacket(g.ctx, packet); err != nil { return err } @@ -449,7 +460,7 @@ func (g *GoBackNConn) sendPacketsForever() error { break } - log.Tracef("The queue is full.") + g.log.Tracef("The queue is full.") // The queue is full. We wait for a ACKs to arrive or // resend the queue after a timeout. @@ -516,7 +527,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // an ACK message with that sequence number // and we bump the sequence number that we // expect of the next data packet. - log.Tracef("Got expected data %d", m.Seq) + g.log.Tracef("Got expected data %d", m.Seq) ack := &PacketACK{ Seq: m.Seq, @@ -551,7 +562,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // it could be that we missed a previous packet. // In either case, we send a NACK with the // sequence number that we were expecting. - log.Tracef("Got unexpected data %d", m.Seq) + g.log.Tracef("Got unexpected data %d", m.Seq) // If we recently sent a NACK for the same // sequence number then back off. @@ -561,7 +572,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo continue } - log.Tracef("Sending NACK %d", g.recvSeq) + g.log.Tracef("Sending NACK %d", g.recvSeq) // Send a NACK with the expected sequence // number. @@ -603,9 +614,9 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // then we ignore it. We must have received the ACK // for the sequence number in the meantime. if !inQueue { - log.Tracef("NACK seq %d is not in the queue. "+ - "Ignoring. (isServer=%v)", m.Seq, - g.isServer) + g.log.Tracef("NACK seq %d is not in the "+ + "queue. Ignoring", m.Seq) + continue } @@ -618,8 +629,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo } } - log.Tracef("Sending a resend signal (isServer=%v)", - g.isServer) + g.log.Tracef("Sending a resend signal") // Send a signal to indicate that new sends should pause // and the current queue should be resent instead. @@ -631,16 +641,15 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo case *PacketFIN: // A FIN packet indicates that the peer would like to // close the connection. - - log.Tracef("Received a FIN packet (isServer=%v)", - g.isServer) + g.log.Tracef("Received a FIN packet") close(g.remoteClosed) return errTransportClosing default: - return fmt.Errorf("received unexpected message: %T", msg) + return fmt.Errorf("received unexpected message: %T", + msg) } } } diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index 488a45db..6aea8156 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -23,7 +23,7 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { - log.Errorf("error closing ServerConn: %v", err) + conn.log.Errorf("Error closing ServerConn: %v", err) } return nil, err @@ -38,9 +38,10 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, // 1. The server waits for a SYN message from the client. // 2. The server then responds with a SYN message. // 3. The server waits for a SYNACK message from the client. -// 4a. If the server receives the SYNACK message before a resendTimeout, the hand -// is considered complete. -// 4b. If SYNACK is not received before a certain resendTimeout +// 4a. If the server receives the SYNACK message before a resendTimeout, the +// handshake is considered complete. +// 4b. If SYNACK is not received before a certain resendTimeout, then the +// handshake is aborted and the process is started from step 1 again. func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo recvChan := make(chan []byte) recvNext := make(chan int, 1) @@ -81,7 +82,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo var n uint8 for { - log.Debugf("Waiting for client SYN") + g.log.Debugf("Waiting for client SYN") select { case <-g.ctx.Done(): return nil @@ -108,13 +109,13 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo switch msg.(type) { case *PacketSYN: default: - log.Tracef("Expected SYN, got %T", msg) + g.log.Tracef("Expected SYN, got %T", msg) continue } recvClientSYN: - log.Debugf("Received client SYN. Sending back.") + g.log.Debugf("Received client SYN. Sending back.") n = msg.(*PacketSYN).N // Send SYN back @@ -129,7 +130,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo } // Wait for SYNACK - log.Debugf("Waiting for client SYNACK") + g.log.Debugf("Waiting for client SYNACK") select { case recvNext <- 1: case <-g.ctx.Done(): @@ -141,7 +142,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo select { case <-time.After(g.handshakeTimeout): - log.Debugf("SYNCACK resendTimeout. Abort and wait " + + g.log.Debugf("SYNCACK resendTimeout. Abort and wait " + "for client to re-initiate") continue case err := <-errChan: @@ -162,7 +163,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo case *PacketSYNACK: break case *PacketSYN: - log.Debugf("Received SYN. Resend SYN.") + g.log.Debugf("Received SYN. Resend SYN.") goto recvClientSYN default: return io.EOF @@ -170,12 +171,13 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo break } - log.Debugf("Received SYNACK") + g.log.Debugf("Received SYNACK") // Set all variables that are dependent on the value of N that we get // from the client g.setN(n) - log.Debugf("Handshake complete (Server)") + g.log.Debugf("Handshake complete (Server)") + return nil } diff --git a/gbn/queue.go b/gbn/queue.go index 23de65a0..1f2aa44a 100644 --- a/gbn/queue.go +++ b/gbn/queue.go @@ -3,6 +3,8 @@ package gbn import ( "sync" "time" + + "github.com/btcsuite/btclog" ) // queue is a fixed size queue with a sliding window that has a base and a top @@ -41,14 +43,24 @@ type queue struct { lastResend time.Time handshakeTimeout time.Duration + + log btclog.Logger } // newQueue creates a new queue. -func newQueue(s uint8, handshakeTimeout time.Duration) *queue { +func newQueue(s uint8, handshakeTimeout time.Duration, + logger btclog.Logger) *queue { + + log := log + if logger != nil { + log = logger + } + return &queue{ content: make([]*PacketData, s), s: s, handshakeTimeout: handshakeTimeout, + log: log, } } @@ -80,7 +92,7 @@ func (q *queue) addPacket(packet *PacketData) { // resend invokes the callback for each packet that needs to be re-sent. func (q *queue) resend(cb func(packet *PacketData) error) error { if time.Since(q.lastResend) < q.handshakeTimeout { - log.Tracef("Resent the queue recently.") + q.log.Tracef("Resent the queue recently.") return nil } @@ -103,7 +115,7 @@ func (q *queue) resend(cb func(packet *PacketData) error) error { return nil } - log.Tracef("Resending the queue") + q.log.Tracef("Resending the queue") for base != top { packet := q.content[base] @@ -113,7 +125,7 @@ func (q *queue) resend(cb func(packet *PacketData) error) error { } base = (base + 1) % q.s - log.Tracef("Resent %d", packet.Seq) + q.log.Tracef("Resent %d", packet.Seq) } return nil @@ -124,8 +136,9 @@ func (q *queue) processACK(seq uint8) bool { // If our queue is empty, an ACK should not have any effect. if q.size() == 0 { - log.Tracef("Received ack %d, but queue is empty. Ignoring.", + q.log.Tracef("Received ack %d, but queue is empty. Ignoring.", seq) + return false } @@ -137,7 +150,7 @@ func (q *queue) processACK(seq uint8) bool { // equal to the one we were expecting. So we increase our base // accordingly and send a signal to indicate that the queue size // has decreased. - log.Tracef("Received correct ack %d", seq) + q.log.Tracef("Received correct ack %d", seq) q.sequenceBase = (q.sequenceBase + 1) % q.s @@ -149,7 +162,7 @@ func (q *queue) processACK(seq uint8) bool { // This could be a duplicate ACK before or it could be that we just // missed the ACK for the current base and this is actually an ACK for // another packet in the queue. - log.Tracef("Received wrong ack %d, expected %d", seq, q.sequenceBase) + q.log.Tracef("Received wrong ack %d, expected %d", seq, q.sequenceBase) q.topMtx.RLock() defer q.topMtx.RUnlock() @@ -158,7 +171,7 @@ func (q *queue) processACK(seq uint8) bool { // just missed a previous ACK. We can bump the base to be equal to this // sequence number. if containsSequence(q.sequenceBase, q.sequenceTop, seq) { - log.Tracef("Sequence %d is in the queue. Bump the base.", seq) + q.log.Tracef("Sequence %d is in the queue. Bump the base.", seq) q.sequenceBase = (seq + 1) % q.s @@ -178,7 +191,7 @@ func (q *queue) processNACK(seq uint8) (bool, bool) { q.topMtx.RLock() defer q.topMtx.RUnlock() - log.Tracef("Received NACK %d", seq) + q.log.Tracef("Received NACK %d", seq) // If the NACK is the same as sequenceTop, it probably means that queue // was sent successfully, but we just missed the necessary ACKs. So we diff --git a/gbn/queue_test.go b/gbn/queue_test.go index b91fc141..fbcace16 100644 --- a/gbn/queue_test.go +++ b/gbn/queue_test.go @@ -7,7 +7,7 @@ import ( ) func TestQueueSize(t *testing.T) { - q := newQueue(4, 0) + q := newQueue(4, 0, nil) require.Equal(t, uint8(0), q.size()) From df80c87d52ac650385cc4404d663d08055384356 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Nov 2023 13:40:16 +0200 Subject: [PATCH 2/4] mailbox: unify Client creation logic --- mailbox/client.go | 68 +++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/mailbox/client.go b/mailbox/client.go index 4e9db7dc..cb37ed1d 100644 --- a/mailbox/client.go +++ b/mailbox/client.go @@ -11,6 +11,17 @@ import ( "google.golang.org/grpc" ) +// ClientOption is the signature of a Client functional option. +type ClientOption func(*Client) + +// WithGrpcConn initialised the grpc client of the Client using the given +// connection. +func WithGrpcConn(conn *grpc.ClientConn) ClientOption { + return func(client *Client) { + client.grpcClient = hashmailrpc.NewHashMailClient(conn) + } +} + // Client manages the mailboxConn it holds and refreshes it on connection // retries. type Client struct { @@ -26,33 +37,49 @@ type Client struct { sid [64]byte - ctx context.Context + ctx context.Context //nolint:containedctx } -// NewGrpcClient creates a new Client object which will handle the mailbox -// connection and will use grpc streams to connect to the mailbox. -func NewGrpcClient(ctx context.Context, serverHost string, connData *ConnData, - dialOpts ...grpc.DialOption) (*Client, error) { - - mailboxGrpcConn, err := grpc.Dial(serverHost, dialOpts...) - if err != nil { - return nil, fmt.Errorf("unable to connect to RPC server: %v", - err) - } +// NewClient creates a new Client object which will handle the mailbox +// connection. +func NewClient(ctx context.Context, serverHost string, connData *ConnData, + opts ...ClientOption) (*Client, error) { sid, err := connData.SID() if err != nil { return nil, err } - return &Client{ + c := &Client{ ctx: ctx, serverHost: serverHost, connData: connData, - grpcClient: hashmailrpc.NewHashMailClient(mailboxGrpcConn), status: ClientStatusNotConnected, sid: sid, - }, nil + } + + // Apply functional options. + for _, o := range opts { + o(c) + } + + return c, nil +} + +// NewGrpcClient creates a new Client object which will handle the mailbox +// connection and will use grpc streams to connect to the mailbox. +func NewGrpcClient(ctx context.Context, serverHost string, connData *ConnData, + dialOpts ...grpc.DialOption) (*Client, error) { + + mailboxGrpcConn, err := grpc.Dial(serverHost, dialOpts...) + if err != nil { + return nil, fmt.Errorf("unable to connect to RPC server: %w", + err) + } + + return NewClient( + ctx, serverHost, connData, WithGrpcConn(mailboxGrpcConn), + ) } // NewWebsocketsClient creates a new Client object which will handle the mailbox @@ -60,18 +87,7 @@ func NewGrpcClient(ctx context.Context, serverHost string, connData *ConnData, func NewWebsocketsClient(ctx context.Context, serverHost string, connData *ConnData) (*Client, error) { - sid, err := connData.SID() - if err != nil { - return nil, err - } - - return &Client{ - ctx: ctx, - serverHost: serverHost, - connData: connData, - status: ClientStatusNotConnected, - sid: sid, - }, nil + return NewClient(ctx, serverHost, connData) } // Dial returns a net.Conn abstraction over the mailbox connection. Dial is From 35c0e8d4a5093060f83f22457dba5bb67ce2c851 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 22 Nov 2023 13:59:20 +0200 Subject: [PATCH 3/4] mailbox: use prefixed logger --- mailbox/client.go | 14 ++++++--- mailbox/client_conn.go | 47 ++++++++++++++++++---------- mailbox/log.go | 13 ++++++++ mailbox/server.go | 19 +++++++---- mailbox/server_conn.go | 71 ++++++++++++++++++++++++++---------------- 5 files changed, 109 insertions(+), 55 deletions(-) diff --git a/mailbox/client.go b/mailbox/client.go index cb37ed1d..14ec95a1 100644 --- a/mailbox/client.go +++ b/mailbox/client.go @@ -7,6 +7,7 @@ import ( "net" "sync" + "github.com/btcsuite/btclog" "github.com/lightninglabs/lightning-node-connect/hashmailrpc" "google.golang.org/grpc" ) @@ -38,6 +39,8 @@ type Client struct { sid [64]byte ctx context.Context //nolint:containedctx + + log btclog.Logger } // NewClient creates a new Client object which will handle the mailbox @@ -56,6 +59,7 @@ func NewClient(ctx context.Context, serverHost string, connData *ConnData, connData: connData, status: ClientStatusNotConnected, sid: sid, + log: newPrefixedLogger(false), } // Apply functional options. @@ -98,12 +102,12 @@ func (c *Client) Dial(_ context.Context, _ string) (net.Conn, error) { // If there is currently an active connection, block here until the // previous connection as been closed. if c.mailboxConn != nil { - log.Debugf("Dial: have existing mailbox connection, waiting") + c.log.Debugf("Dial: have existing mailbox connection, waiting") <-c.mailboxConn.Done() - log.Debugf("Dial: done with existing conn") + c.log.Debugf("Dial: done with existing conn") } - log.Debugf("Client: Dialing...") + c.log.Debugf("Dialing...") sid, err := c.connData.SID() if err != nil { @@ -115,7 +119,7 @@ func (c *Client) Dial(_ context.Context, _ string) (net.Conn, error) { if !bytes.Equal(c.sid[:], sid[:]) && c.mailboxConn != nil { err := c.mailboxConn.Close() if err != nil { - log.Errorf("could not close mailbox conn: %v", err) + c.log.Errorf("Could not close mailbox conn: %v", err) } c.mailboxConn = nil @@ -126,7 +130,7 @@ func (c *Client) Dial(_ context.Context, _ string) (net.Conn, error) { if c.mailboxConn == nil { mailboxConn, err := NewClientConn( c.ctx, c.sid, c.serverHost, c.grpcClient, - func(status ClientStatus) { + c.log, func(status ClientStatus) { c.statusMu.Lock() c.status = status c.statusMu.Unlock() diff --git a/mailbox/client_conn.go b/mailbox/client_conn.go index b6428347..6355a942 100644 --- a/mailbox/client_conn.go +++ b/mailbox/client_conn.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/btcsuite/btclog" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightninglabs/lightning-node-connect/gbn" "github.com/lightninglabs/lightning-node-connect/hashmailrpc" @@ -124,13 +125,15 @@ type ClientConn struct { quit chan struct{} cancel func() closeOnce sync.Once + + log btclog.Logger } // NewClientConn creates a new client connection with the given receive and send // session identifiers. The context given as the first parameter will be used // throughout the connection lifetime. func NewClientConn(ctx context.Context, sid [64]byte, serverHost string, - client hashmailrpc.HashMailClient, + client hashmailrpc.HashMailClient, logger btclog.Logger, onNewStatus func(status ClientStatus)) (*ClientConn, error) { receiveSID := GetSID(sid, true) @@ -141,7 +144,7 @@ func NewClientConn(ctx context.Context, sid [64]byte, serverHost string, sendSID: sendSID[:], } - log.Debugf("New client conn, read_stream=%x, write_stream=%x", + logger.Debugf("New conn, read_stream=%x, write_stream=%x", receiveSID[:], sendSID[:]) ctxc, cancel := context.WithCancel(ctx) @@ -166,6 +169,7 @@ func NewClientConn(ctx context.Context, sid [64]byte, serverHost string, onNewStatus: onNewStatus, quit: make(chan struct{}), cancel: cancel, + log: logger, } c.connKit = &connKit{ ctx: ctxc, @@ -201,10 +205,11 @@ func RefreshClientConn(ctx context.Context, c *ClientConn) (*ClientConn, c.statusMu.Lock() defer c.statusMu.Unlock() - log.Debugf("Refreshing client conn, read_stream=%x, write_stream=%x", + c.log.Debugf("Refreshing client conn, read_stream=%x, write_stream=%x", c.receiveSID[:], c.sendSID[:]) cc := &ClientConn{ + log: c.log, transport: c.transport.Refresh(), status: ClientStatusNotConnected, onNewStatus: c.onNewStatus, @@ -299,7 +304,7 @@ func (c *ClientConn) recv(ctx context.Context) ([]byte, error) { return nil, err } - log.Debugf("Client: got failure on receive "+ + c.log.Debugf("Got failure on receive "+ "socket/stream, re-trying: %v", err) c.setStatus(errStatus) @@ -343,8 +348,8 @@ func (c *ClientConn) send(ctx context.Context, payload []byte) error { return err } - log.Debugf("Client: got failure on send "+ - "socket/stream, re-trying: %v", err) + c.log.Debugf("Got failure on send socket/stream, "+ + "re-trying: %v", err) c.setStatus(errStatus) c.createSendMailBox(ctx, retryWait) @@ -377,13 +382,14 @@ func (c *ClientConn) createReceiveMailBox(ctx context.Context, waiter.Wait() if err := c.transport.ConnectReceive(ctx); err != nil { - log.Errorf("Client: error connecting to receive "+ + c.log.Errorf("Error connecting to receive "+ "socket/stream: %v", err) continue } - log.Debugf("Client: receive mailbox initialized") + c.log.Debugf("Receive mailbox initialized") + return } } @@ -406,14 +412,17 @@ func (c *ClientConn) createSendMailBox(ctx context.Context, waiter.Wait() - log.Debugf("Client: Attempting to create send socket/stream") + c.log.Debugf("Attempting to create send socket/stream") + if err := c.transport.ConnectSend(ctx); err != nil { - log.Debugf("Client: error connecting to send "+ + c.log.Debugf("Error connecting to send "+ "stream/socket %v", err) + continue } - log.Debugf("Client: Connected to send socket/stream") + c.log.Debugf("Connected to send socket/stream") + return } } @@ -460,11 +469,11 @@ func (c *ClientConn) SetSendTimeout(timeout time.Duration) { func (c *ClientConn) Close() error { var returnErr error c.closeOnce.Do(func() { - log.Debugf("Closing client connection") + c.log.Debugf("Closing connection") if c.gbnConn != nil { if err := c.gbnConn.Close(); err != nil { - log.Debugf("Error closing gbn connection: %v", + c.log.Debugf("Error closing gbn connection: %v", err) returnErr = err @@ -472,17 +481,21 @@ func (c *ClientConn) Close() error { } c.receiveMu.Lock() - log.Debugf("closing receive stream/socket") + c.log.Debugf("Closing receive stream/socket") if err := c.transport.CloseReceive(); err != nil { - log.Errorf("Error closing receive stream/socket: %v", err) + c.log.Errorf("Error closing receive stream/socket: %v", + err) + returnErr = err } c.receiveMu.Unlock() c.sendMu.Lock() - log.Debugf("closing send stream/socket") + c.log.Debugf("Closing send stream/socket") if err := c.transport.CloseSend(); err != nil { - log.Errorf("Error closing send stream/socket: %v", err) + c.log.Errorf("Error closing send stream/socket: %v", + err) + returnErr = err } c.sendMu.Unlock() diff --git a/mailbox/log.go b/mailbox/log.go index 13cbb5f3..98f8a207 100644 --- a/mailbox/log.go +++ b/mailbox/log.go @@ -1,6 +1,8 @@ package mailbox import ( + "fmt" + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" "google.golang.org/grpc/grpclog" @@ -29,6 +31,17 @@ func UseLogger(logger btclog.Logger) { log = logger } +// nePrefixedLogger constructs a new prefixed logger. +func newPrefixedLogger(isServer bool) *build.PrefixLog { + identifier := "client" + if isServer { + identifier = "server" + } + prefix := fmt.Sprintf("(%s)", identifier) + + return build.NewPrefixLog(prefix, log) +} + // GrpcLogLogger is a wrapper around a btclog logger to make it compatible with // the grpclog logger package. By default we downgrade the info level to debug // to reduce the verbosity of the logger. diff --git a/mailbox/server.go b/mailbox/server.go index f810fbc8..c80314e2 100644 --- a/mailbox/server.go +++ b/mailbox/server.go @@ -7,6 +7,7 @@ import ( "io" "net" + "github.com/btcsuite/btclog" "github.com/lightninglabs/lightning-node-connect/hashmailrpc" "google.golang.org/grpc" ) @@ -30,6 +31,8 @@ type Server struct { quit chan struct{} cancel func() + + log btclog.Logger } func NewServer(serverHost string, connData *ConnData, @@ -55,6 +58,7 @@ func NewServer(serverHost string, connData *ConnData, connData: connData, sid: sid, onNewStatus: onNewStatus, + log: newPrefixedLogger(true), quit: make(chan struct{}), } @@ -79,12 +83,14 @@ func (s *Server) Accept() (net.Conn, error) { // If there is currently an active connection, block here until the // previous connection as been closed. if s.mailboxConn != nil { - log.Debugf("Accept: have existing mailbox connection, waiting") + s.log.Debugf("Accept: have existing mailbox connection, " + + "waiting") + select { case <-s.quit: return nil, io.EOF case <-s.mailboxConn.Done(): - log.Debugf("Accept: done with existing conn") + s.log.Debugf("Accept: done with existing conn") } } @@ -98,7 +104,7 @@ func (s *Server) Accept() (net.Conn, error) { if !bytes.Equal(s.sid[:], sid[:]) && s.mailboxConn != nil { err := s.mailboxConn.Stop() if err != nil { - log.Errorf("could not close mailbox conn: %v", err) + s.log.Errorf("Could not close mailbox conn: %v", err) } s.mailboxConn = nil @@ -110,7 +116,8 @@ func (s *Server) Accept() (net.Conn, error) { // otherwise, we just refresh the ServerConn. if s.mailboxConn == nil { mailboxConn, err := NewServerConn( - s.ctx, s.serverHost, s.client, sid, s.onNewStatus, + s.ctx, s.serverHost, s.client, sid, s.log, + s.onNewStatus, ) if err != nil { return nil, &temporaryError{err} @@ -143,13 +150,13 @@ func (e *temporaryError) Temporary() bool { } func (s *Server) Close() error { - log.Debugf("conn being closed") + s.log.Debugf("Conn being closed") close(s.quit) if s.mailboxConn != nil { if err := s.mailboxConn.Stop(); err != nil { - log.Errorf("error closing mailboxConn %v", err) + s.log.Errorf("Error closing mailboxConn %v", err) } } s.cancel() diff --git a/mailbox/server_conn.go b/mailbox/server_conn.go index a1c89305..ac75055c 100644 --- a/mailbox/server_conn.go +++ b/mailbox/server_conn.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/btcsuite/btclog" "github.com/lightninglabs/lightning-node-connect/gbn" "github.com/lightninglabs/lightning-node-connect/hashmailrpc" "google.golang.org/grpc/codes" @@ -55,6 +56,8 @@ type ServerConn struct { onNewStatus func(status ServerStatus) statusMu sync.Mutex + log btclog.Logger + cancel func() quit chan struct{} @@ -64,7 +67,7 @@ type ServerConn struct { // NewServerConn creates a new net.Conn compatible server connection that uses // a gRPC based connection to tunnel traffic over a mailbox server. func NewServerConn(ctx context.Context, serverHost string, - client hashmailrpc.HashMailClient, sid [64]byte, + client hashmailrpc.HashMailClient, sid [64]byte, logger btclog.Logger, onNewStatus func(status ServerStatus)) (*ServerConn, error) { ctxc, cancel := context.WithCancel(ctx) @@ -84,6 +87,7 @@ func NewServerConn(ctx context.Context, serverHost string, ), }, status: ServerStatusNotConnected, + log: logger, onNewStatus: onNewStatus, } c.connKit = &connKit{ @@ -94,7 +98,7 @@ func NewServerConn(ctx context.Context, serverHost string, sendSID: sendSID, } - log.Debugf("ServerConn: creating gbn, waiting for sync") + logger.Debugf("Creating gbn, waiting for sync") var err error c.gbnConn, err = gbn.NewServerConn( ctxc, c.sendToStream, c.recvFromStream, c.gbnOptions..., @@ -102,7 +106,7 @@ func NewServerConn(ctx context.Context, serverHost string, if err != nil { return nil, err } - log.Debugf("ServerConn: done creating gbn") + logger.Debugf("Done creating gbn") return c, nil } @@ -128,6 +132,7 @@ func RefreshServerConn(s *ServerConn) (*ServerConn, error) { cancel: s.cancel, status: ServerStatusNotConnected, onNewStatus: s.onNewStatus, + log: s.log, quit: make(chan struct{}), } @@ -140,7 +145,7 @@ func RefreshServerConn(s *ServerConn) (*ServerConn, error) { sendSID: s.connKit.sendSID, } - log.Debugf("ServerConn: creating gbn") + s.log.Debugf("ServerConn: creating gbn") var err error sc.gbnConn, err = gbn.NewServerConn( sc.ctx, sc.sendToStream, sc.recvFromStream, sc.gbnOptions..., @@ -149,7 +154,7 @@ func RefreshServerConn(s *ServerConn) (*ServerConn, error) { return nil, err } - log.Debugf("ServerConn: done creating gbn") + s.log.Debugf("ServerConn: done creating gbn") return sc, nil } @@ -174,7 +179,7 @@ func (c *ServerConn) recvFromStream(ctx context.Context) ([]byte, error) { c.receiveStreamMu.Lock() controlMsg, err := c.receiveStream.Recv() if err != nil { - log.Debugf("Server: got failure on receive socket, "+ + c.log.Debugf("Got failure on receive socket, "+ "re-trying: %v", err) c.setStatus(ServerStatusNotConnected) @@ -215,7 +220,7 @@ func (c *ServerConn) sendToStream(ctx context.Context, payload []byte) error { Msg: payload, }) if err != nil { - log.Debugf("Server: got failure on send socket, "+ + c.log.Debugf("Got failure on send socket, "+ "re-trying: %v", err) c.setStatus(ServerStatusNotConnected) @@ -289,8 +294,8 @@ func (c *ServerConn) createReceiveMailBox(ctx context.Context, // Create receive mailbox and get receive stream. err := initAccountCipherBox(ctx, c.client, c.receiveSID) if err != nil && !isErrAlreadyExists(err) { - log.Debugf("Server: failed to re-create read stream "+ - "mbox: %v", err) + c.log.Debugf("Failed to re-create read stream mbox: %v", + err) continue } @@ -306,8 +311,7 @@ func (c *ServerConn) createReceiveMailBox(ctx context.Context, } readStream, err := c.client.RecvStream(ctx, streamDesc) if err != nil { - log.Debugf("Server: failed to create read stream: %w", - err) + c.log.Debugf("Failed to create read stream: %w", err) continue } @@ -315,7 +319,8 @@ func (c *ServerConn) createReceiveMailBox(ctx context.Context, c.setStatus(ServerStatusIdle) c.receiveStream = readStream - log.Debugf("Server: receive mailbox created") + c.log.Debugf("Receive mailbox created") + return } } @@ -341,7 +346,8 @@ func (c *ServerConn) createSendMailBox(ctx context.Context, // Create send mailbox and get send stream. err := initAccountCipherBox(ctx, c.client, c.sendSID) if err != nil && !isErrAlreadyExists(err) { - log.Debugf("error creating send cipher box: %v", err) + c.log.Debugf("Error creating send cipher box: %v", err) + continue } c.sendBoxCreated = true @@ -353,12 +359,14 @@ func (c *ServerConn) createSendMailBox(ctx context.Context, // and exit if needed. writeStream, err := c.client.SendStream(ctx) if err != nil { - log.Debugf("unable to create send stream: %w", err) + c.log.Debugf("Unable to create send stream: %w", err) + continue } c.sendStream = writeStream - log.Debugf("Server: Send mailbox created") + c.log.Debugf("Send mailbox created") + return } } @@ -368,19 +376,23 @@ func (c *ServerConn) createSendMailBox(ctx context.Context, func (c *ServerConn) Stop() error { var returnErr error if err := c.Close(); err != nil { - log.Errorf("error closing mailbox") + c.log.Errorf("Error closing mailbox") + returnErr = err } if c.receiveBoxCreated { - if err := delCipherBox(c.ctx, c.client, c.receiveSID); err != nil { - log.Errorf("error removing receive cipher box: %v", err) + err := delCipherBox(c.ctx, c.client, c.receiveSID) + if err != nil { + c.log.Errorf("Error removing receive cipher box: %v", + err) + returnErr = err } } if c.sendBoxCreated { if err := delCipherBox(c.ctx, c.client, c.sendSID); err != nil { - log.Errorf("error removing send cipher box: %v", err) + c.log.Errorf("Error removing send cipher box: %v", err) returnErr = err } } @@ -403,34 +415,39 @@ func (c *ServerConn) Close() error { var returnErr error c.closeOnce.Do(func() { - log.Debugf("Server connection is closing") + c.log.Debugf("Connection is closing") if c.gbnConn != nil { if err := c.gbnConn.Close(); err != nil { - log.Debugf("Error closing gbn connection in " + - "server conn") + c.log.Debugf("Error closing gbn connection " + + "in server conn") + returnErr = err } } if c.receiveStream != nil { - log.Debugf("closing receive stream") + c.log.Debugf("Closing receive stream") if err := c.receiveStream.CloseSend(); err != nil { - log.Errorf("error closing receive stream: %v", err) + c.log.Errorf("Error closing receive stream: %v", + err) + returnErr = err } } if c.sendStream != nil { - log.Debugf("closing send stream") + c.log.Debugf("Closing send stream") if err := c.sendStream.CloseSend(); err != nil { - log.Errorf("error closing send stream: %v", err) + c.log.Errorf("Error closing send stream: %v", + err) + returnErr = err } } close(c.quit) - log.Debugf("Server connection closed") + c.log.Debugf("Connection closed") }) return returnErr From 72e5adfada3458c42e393e7e5d94a3ad4b0d56f8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Nov 2023 13:32:41 +0200 Subject: [PATCH 4/4] gbn: split out config values into a config struct --- gbn/config.go | 62 ++++++++++++++++++++++ gbn/gbn_client.go | 20 ++++---- gbn/gbn_conn.go | 127 ++++++++++++++++------------------------------ gbn/gbn_server.go | 12 +++-- gbn/options.go | 11 ++-- 5 files changed, 129 insertions(+), 103 deletions(-) create mode 100644 gbn/config.go diff --git a/gbn/config.go b/gbn/config.go new file mode 100644 index 00000000..4d14f7f2 --- /dev/null +++ b/gbn/config.go @@ -0,0 +1,62 @@ +package gbn + +import "time" + +// config holds the configuration values for an instance of GoBackNConn. +type config struct { + // n is the window size. The sender can send a maximum of n packets + // before requiring an ack from the receiver for the first packet in + // the window. The value of n is chosen by the client during the + // GoBN handshake. + n uint8 + + // s is the maximum sequence number used to label packets. Packets + // are labelled with incrementing sequence numbers modulo s. + // s must be strictly larger than the window size, n. This + // is so that the receiver can tell if the sender is resending the + // previous window (maybe the sender did not receive the acks) or if + // they are sending the next window. If s <= n then there would be + // no way to tell. + s uint8 + + // maxChunkSize is the maximum payload size in bytes allowed per + // message. If the payload to be sent is larger than maxChunkSize then + // the payload will be split between multiple packets. + // If maxChunkSize is zero then it is disabled and data won't be split + // between packets. + maxChunkSize int + + // resendTimeout is the duration that will be waited before resending + // the packets in the current queue. + resendTimeout time.Duration + + // recvFromStream is the function that will be used to acquire the next + // available packet. + recvFromStream recvBytesFunc + + // sendToStream is the function that will be used to send over our next + // packet. + sendToStream sendBytesFunc + + // handshakeTimeout is the time after which the server or client + // will abort and restart the handshake if the expected response is + // not received from the peer. + handshakeTimeout time.Duration + + pingTime time.Duration + pongTime time.Duration +} + +// newConfig constructs a new config struct. +func newConfig(sendFunc sendBytesFunc, recvFunc recvBytesFunc, + n uint8) *config { + + return &config{ + n: n, + s: n + 1, + recvFromStream: recvFunc, + sendToStream: sendFunc, + resendTimeout: defaultResendTimeout, + handshakeTimeout: defaultHandshakeTimeout, + } +} diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index db750e96..c1032e26 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -21,16 +21,18 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, math.MaxUint8) } - conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n) + cfg := newConfig(sendFunc, receiveFunc, n) // Apply functional options for _, o := range opts { - o(conn) + o(cfg) } + conn := newGoBackNConn(ctx, cfg, "client") + if err := conn.clientHandshake(); err != nil { if err := conn.Close(); err != nil { - log.Errorf("error closing gbn ClientConn: %v", err) + conn.log.Errorf("error closing gbn ClientConn: %v", err) } return nil, err } @@ -76,7 +78,7 @@ func (g *GoBackNConn) clientHandshake() error { case <-recvNext: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { errChan <- err return @@ -101,7 +103,7 @@ func (g *GoBackNConn) clientHandshake() error { handshake: for { // start Handshake - msg := &PacketSYN{N: g.n} + msg := &PacketSYN{N: g.cfg.n} msgBytes, err := msg.Serialize() if err != nil { return err @@ -109,7 +111,7 @@ handshake: // Send SYN g.log.Debugf("Sending SYN") - if err := g.sendToStream(g.ctx, msgBytes); err != nil { + if err := g.cfg.sendToStream(g.ctx, msgBytes); err != nil { return err } @@ -128,7 +130,7 @@ handshake: var b []byte select { - case <-time.After(g.handshakeTimeout): + case <-time.After(g.cfg.handshakeTimeout): g.log.Debugf("SYN resendTimeout. Resending " + "SYN.") @@ -165,7 +167,7 @@ handshake: g.log.Debugf("Got SYN") - if respSYN.N != g.n { + if respSYN.N != g.cfg.n { return io.EOF } @@ -176,7 +178,7 @@ handshake: return err } - if err := g.sendToStream(g.ctx, synack); err != nil { + if err := g.cfg.sendToStream(g.ctx, synack); err != nil { return err } diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index 39275ccf..e1a8e802 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -33,27 +33,7 @@ type sendBytesFunc func(ctx context.Context, b []byte) error type recvBytesFunc func(ctx context.Context) ([]byte, error) type GoBackNConn struct { - // n is the window size. The sender can send a maximum of n packets - // before requiring an ack from the receiver for the first packet in - // the window. The value of n is chosen by the client during the - // GoBN handshake. - n uint8 - - // s is the maximum sequence number used to label packets. Packets - // are labelled with incrementing sequence numbers modulo s. - // s must be strictly larger than the window size, n. This - // is so that the receiver can tell if the sender is resending the - // previous window (maybe the sender did not receive the acks) or if - // they are sending the next window. If s <= n then there would be - // no way to tell. - s uint8 - - // maxChunkSize is the maximum payload size in bytes allowed per - // message. If the payload to be sent is larger than maxChunkSize then - // the payload will be split between multiple packets. - // If maxChunkSize is zero then it is disabled and data won't be split - // between packets. - maxChunkSize int + cfg *config sendQueue *queue @@ -61,30 +41,17 @@ type GoBackNConn struct { // sequence that we have received. recvSeq uint8 - // resendTimeout is the duration that will be waited before resending - // the packets in the current queue. - resendTimeout time.Duration - resendTicker *time.Ticker - - recvFromStream recvBytesFunc - sendToStream sendBytesFunc + resendTicker *time.Ticker recvDataChan chan *PacketData sendDataChan chan *PacketData - sendTimeout time.Duration - sendTimeoutMu sync.RWMutex - - recvTimeout time.Duration - recvTimeoutMu sync.RWMutex + sendTimeout time.Duration + recvTimeout time.Duration + timeoutsMu sync.RWMutex log btclog.Logger - // handshakeTimeout is the time after which the server or client - // will abort and restart the handshake if the expected response is - // not received from the peer. - handshakeTimeout time.Duration - // receivedACKSignal channel is used to signal that the queue size has // been decreased. receivedACKSignal chan struct{} @@ -94,11 +61,8 @@ type GoBackNConn struct { // that this channel should only be listened on in one place. resendSignal chan struct{} - pingTime time.Duration - pongTime time.Duration pingTicker *IntervalAwareForceTicker pongTicker *IntervalAwareForceTicker - pongWait chan struct{} ctx context.Context //nolint:containedctx cancel func() @@ -116,29 +80,22 @@ type GoBackNConn struct { // newGoBackNConn creates a GoBackNConn instance with all the members which // are common between client and server initialised. -func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, - recvFunc recvBytesFunc, isServer bool, n uint8) *GoBackNConn { +func newGoBackNConn(ctx context.Context, cfg *config, + loggerPrefix string) *GoBackNConn { ctxc, cancel := context.WithCancel(ctx) // Construct a new prefixed logger. - identifier := "client" - if isServer { - identifier = "server" - } - prefix := fmt.Sprintf("(%s)", identifier) + prefix := fmt.Sprintf("(%s)", loggerPrefix) plog := build.NewPrefixLog(prefix, log) return &GoBackNConn{ - n: n, - s: n + 1, - resendTimeout: defaultResendTimeout, - recvFromStream: recvFunc, - sendToStream: sendFunc, - recvDataChan: make(chan *PacketData, n), - sendDataChan: make(chan *PacketData), - sendQueue: newQueue(n+1, defaultHandshakeTimeout, plog), - handshakeTimeout: defaultHandshakeTimeout, + cfg: cfg, + recvDataChan: make(chan *PacketData, cfg.n), + sendDataChan: make(chan *PacketData), + sendQueue: newQueue( + cfg.n+1, defaultHandshakeTimeout, plog, + ), recvTimeout: DefaultRecvTimeout, sendTimeout: DefaultSendTimeout, receivedACKSignal: make(chan struct{}), @@ -154,24 +111,24 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, // setN sets the current N to use. This _must_ be set before the handshake is // completed. func (g *GoBackNConn) setN(n uint8) { - g.n = n - g.s = n + 1 + g.cfg.n = n + g.cfg.s = n + 1 g.recvDataChan = make(chan *PacketData, n) g.sendQueue = newQueue(n+1, defaultHandshakeTimeout, g.log) } // SetSendTimeout sets the timeout used in the Send function. func (g *GoBackNConn) SetSendTimeout(timeout time.Duration) { - g.sendTimeoutMu.Lock() - defer g.sendTimeoutMu.Unlock() + g.timeoutsMu.Lock() + defer g.timeoutsMu.Unlock() g.sendTimeout = timeout } // SetRecvTimeout sets the timeout used in the Recv function. func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) { - g.recvTimeoutMu.Lock() - defer g.recvTimeoutMu.Unlock() + g.timeoutsMu.Lock() + defer g.timeoutsMu.Unlock() g.recvTimeout = timeout } @@ -185,9 +142,9 @@ func (g *GoBackNConn) Send(data []byte) error { default: } - g.sendTimeoutMu.RLock() + g.timeoutsMu.RLock() ticker := time.NewTimer(g.sendTimeout) - g.sendTimeoutMu.RUnlock() + g.timeoutsMu.RUnlock() defer ticker.Stop() sendPacket := func(packet *PacketData) error { @@ -201,7 +158,7 @@ func (g *GoBackNConn) Send(data []byte) error { } } - if g.maxChunkSize == 0 { + if g.cfg.maxChunkSize == 0 { // Splitting is disabled. return sendPacket(&PacketData{ Payload: data, @@ -210,18 +167,21 @@ func (g *GoBackNConn) Send(data []byte) error { } // Splitting is enabled. Split into packets no larger than maxChunkSize. - sentBytes := 0 + var ( + sentBytes = 0 + maxChunk = g.cfg.maxChunkSize + ) for sentBytes < len(data) { packet := &PacketData{} remainingBytes := len(data) - sentBytes - if remainingBytes <= g.maxChunkSize { + if remainingBytes <= maxChunk { packet.Payload = data[sentBytes:] sentBytes += remainingBytes packet.FinalChunk = true } else { - packet.Payload = data[sentBytes : sentBytes+g.maxChunkSize] - sentBytes += g.maxChunkSize + packet.Payload = data[sentBytes : sentBytes+maxChunk] + sentBytes += maxChunk } if err := sendPacket(packet); err != nil { @@ -246,9 +206,9 @@ func (g *GoBackNConn) Recv() ([]byte, error) { msg *PacketData ) - g.recvTimeoutMu.RLock() + g.timeoutsMu.RLock() ticker := time.NewTimer(g.recvTimeout) - g.recvTimeoutMu.RUnlock() + g.timeoutsMu.RUnlock() defer ticker.Stop() for { @@ -276,21 +236,21 @@ func (g *GoBackNConn) start() { g.log.Debugf("Starting") pingTime := time.Duration(math.MaxInt64) - if g.pingTime != 0 { - pingTime = g.pingTime + if g.cfg.pingTime != 0 { + pingTime = g.cfg.pingTime } g.pingTicker = NewIntervalAwareForceTicker(pingTime) g.pingTicker.Resume() pongTime := time.Duration(math.MaxInt64) - if g.pongTime != 0 { - pongTime = g.pongTime + if g.cfg.pongTime != 0 { + pongTime = g.cfg.pongTime } g.pongTicker = NewIntervalAwareForceTicker(pongTime) - g.resendTicker = time.NewTicker(g.resendTimeout) + g.resendTicker = time.NewTicker(g.cfg.resendTimeout) g.wg.Add(1) go func() { @@ -382,7 +342,7 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { return fmt.Errorf("serialize error: %s", err) } - err = g.sendToStream(ctx, b) + err = g.cfg.sendToStream(ctx, b) if err != nil { return fmt.Errorf("error calling sendToStream: %s", err) } @@ -456,7 +416,7 @@ func (g *GoBackNConn) sendPacketsForever() error { for { // If the queue size is still less than N, we can // continue to add more packets to the queue. - if g.sendQueue.size() < g.n { + if g.sendQueue.size() < g.cfg.n { break } @@ -500,7 +460,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo default: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { return fmt.Errorf("error receiving "+ "from recvFromStream: %s", err) @@ -537,7 +497,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo return err } - g.recvSeq = (g.recvSeq + 1) % g.s + g.recvSeq = (g.recvSeq + 1) % g.cfg.s // If the packet was a ping, then there is no // data to return to the above layer. @@ -567,7 +527,8 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // If we recently sent a NACK for the same // sequence number then back off. if lastNackSeq == g.recvSeq && - time.Since(lastNackTime) < g.resendTimeout { + time.Since(lastNackTime) < + g.cfg.resendTimeout { continue } @@ -591,7 +552,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo case *PacketACK: gotValidACK := g.sendQueue.processACK(m.Seq) if gotValidACK { - g.resendTicker.Reset(g.resendTimeout) + g.resendTicker.Reset(g.cfg.resendTimeout) // Send a signal to indicate that new // ACKs have been received. diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index 6aea8156..68e3b5b7 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -14,13 +14,15 @@ import ( func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, recvFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { - conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN) + cfg := newConfig(sendFunc, recvFunc, DefaultN) // Apply functional options for _, o := range opts { - o(conn) + o(cfg) } + conn := newGoBackNConn(ctx, cfg, "server") + if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { conn.log.Errorf("Error closing ServerConn: %v", err) @@ -61,7 +63,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo case <-recvNext: } - b, err := g.recvFromStream(g.ctx) + b, err := g.cfg.recvFromStream(g.ctx) if err != nil { errChan <- err return @@ -125,7 +127,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo return err } - if err = g.sendToStream(g.ctx, b); err != nil { + if err = g.cfg.sendToStream(g.ctx, b); err != nil { return err } @@ -141,7 +143,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo } select { - case <-time.After(g.handshakeTimeout): + case <-time.After(g.cfg.handshakeTimeout): g.log.Debugf("SYNCACK resendTimeout. Abort and wait " + "for client to re-initiate") continue diff --git a/gbn/options.go b/gbn/options.go index 2c24b746..2c1e8df1 100644 --- a/gbn/options.go +++ b/gbn/options.go @@ -2,14 +2,14 @@ package gbn import "time" -type Option func(conn *GoBackNConn) +type Option func(conn *config) // WithMaxSendSize is used to set the maximum payload size in bytes per packet. // If set and a large payload comes through then it will be split up into // multiple packets with payloads no larger than the given maximum size. // A size of zero will disable splitting. func WithMaxSendSize(size int) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.maxChunkSize = size } } @@ -17,7 +17,7 @@ func WithMaxSendSize(size int) Option { // WithTimeout is used to set the resend timeout. This is the time to wait // for ACKs before resending the queue. func WithTimeout(timeout time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.resendTimeout = timeout } } @@ -26,7 +26,7 @@ func WithTimeout(timeout time.Duration) Option { // If the timeout is reached without response from the peer then the handshake // will be aborted and restarted. func WithHandshakeTimeout(timeout time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.handshakeTimeout = timeout } } @@ -38,9 +38,8 @@ func WithHandshakeTimeout(timeout time.Duration) Option { // the connection will be closed if the other side does not respond within // time duration. func WithKeepalivePing(ping, pong time.Duration) Option { - return func(conn *GoBackNConn) { + return func(conn *config) { conn.pingTime = ping conn.pongTime = pong - conn.pongWait = make(chan struct{}, 1) } }