diff --git a/pkg/networking/handler.go b/pkg/networking/handler.go index 2f4f62587..81b81cb39 100644 --- a/pkg/networking/handler.go +++ b/pkg/networking/handler.go @@ -1,9 +1,11 @@ package networking +import "io" + // Handler is an interface for handling new messages, handshakes and session close events. type Handler interface { // OnReceive fired on new message received. - OnReceive(*Session, []byte) + OnReceive(*Session, io.Reader) // OnHandshake fired on new Handshake received. OnHandshake(*Session, Handshake) diff --git a/pkg/networking/mocks/handler.go b/pkg/networking/mocks/handler.go index d7ba29dd3..a11fdc547 100644 --- a/pkg/networking/mocks/handler.go +++ b/pkg/networking/mocks/handler.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking import ( + io "io" + mock "github.com/stretchr/testify/mock" networking "github.com/wavesplatform/gowaves/pkg/networking" ) @@ -49,7 +51,7 @@ func (_c *MockHandler_OnClose_Call) Return() *MockHandler_OnClose_Call { } func (_c *MockHandler_OnClose_Call) RunAndReturn(run func(*networking.Session)) *MockHandler_OnClose_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -83,12 +85,12 @@ func (_c *MockHandler_OnHandshake_Call) Return() *MockHandler_OnHandshake_Call { } func (_c *MockHandler_OnHandshake_Call) RunAndReturn(run func(*networking.Session, networking.Handshake)) *MockHandler_OnHandshake_Call { - _c.Call.Return(run) + _c.Run(run) return _c } // OnReceive provides a mock function with given fields: _a0, _a1 -func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 []byte) { +func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 io.Reader) { _m.Called(_a0, _a1) } @@ -99,14 +101,14 @@ type MockHandler_OnReceive_Call struct { // OnReceive is a helper method to define mock.On call // - _a0 *networking.Session -// - _a1 []byte +// - _a1 io.Reader func (_e *MockHandler_Expecter) OnReceive(_a0 interface{}, _a1 interface{}) *MockHandler_OnReceive_Call { return &MockHandler_OnReceive_Call{Call: _e.mock.On("OnReceive", _a0, _a1)} } -func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 []byte)) *MockHandler_OnReceive_Call { +func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 io.Reader)) *MockHandler_OnReceive_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*networking.Session), args[1].([]byte)) + run(args[0].(*networking.Session), args[1].(io.Reader)) }) return _c } @@ -116,8 +118,8 @@ func (_c *MockHandler_OnReceive_Call) Return() *MockHandler_OnReceive_Call { return _c } -func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, []byte)) *MockHandler_OnReceive_Call { - _c.Call.Return(run) +func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, io.Reader)) *MockHandler_OnReceive_Call { + _c.Run(run) return _c } diff --git a/pkg/networking/mocks/header.go b/pkg/networking/mocks/header.go index 1de986214..eabade26a 100644 --- a/pkg/networking/mocks/header.go +++ b/pkg/networking/mocks/header.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking @@ -21,7 +21,7 @@ func (_m *MockHeader) EXPECT() *MockHeader_Expecter { return &MockHeader_Expecter{mock: &_m.Mock} } -// HeaderLength provides a mock function with given fields: +// HeaderLength provides a mock function with no fields func (_m *MockHeader) HeaderLength() uint32 { ret := _m.Called() @@ -66,7 +66,7 @@ func (_c *MockHeader_HeaderLength_Call) RunAndReturn(run func() uint32) *MockHea return _c } -// PayloadLength provides a mock function with given fields: +// PayloadLength provides a mock function with no fields func (_m *MockHeader) PayloadLength() uint32 { ret := _m.Called() diff --git a/pkg/networking/mocks/protocol.go b/pkg/networking/mocks/protocol.go index 19afa9cff..dc30f5d74 100644 --- a/pkg/networking/mocks/protocol.go +++ b/pkg/networking/mocks/protocol.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking @@ -20,7 +20,7 @@ func (_m *MockProtocol) EXPECT() *MockProtocol_Expecter { return &MockProtocol_Expecter{mock: &_m.Mock} } -// EmptyHandshake provides a mock function with given fields: +// EmptyHandshake provides a mock function with no fields func (_m *MockProtocol) EmptyHandshake() networking.Handshake { ret := _m.Called() @@ -67,7 +67,7 @@ func (_c *MockProtocol_EmptyHandshake_Call) RunAndReturn(run func() networking.H return _c } -// EmptyHeader provides a mock function with given fields: +// EmptyHeader provides a mock function with no fields func (_m *MockProtocol) EmptyHeader() networking.Header { ret := _m.Called() @@ -206,7 +206,7 @@ func (_c *MockProtocol_IsAcceptableMessage_Call) RunAndReturn(run func(networkin return _c } -// Ping provides a mock function with given fields: +// Ping provides a mock function with no fields func (_m *MockProtocol) Ping() ([]byte, error) { ret := _m.Called() diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 2b7f6e72a..6cfc79e73 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -162,9 +162,8 @@ func (s *Session) waitForSend(data []byte) error { if s.logger.Enabled(s.ctx, slog.LevelDebug) { s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) } - ready := &sendPacket{data: data, err: errCh} select { - case s.sendCh <- ready: + case s.sendCh <- newSendPacket(data, errCh): s.logger.Debug("Data written into send channel") case <-s.ctx.Done(): s.logger.Debug("Session shutdown while sending data") @@ -174,24 +173,6 @@ func (s *Session) waitForSend(data []byte) error { return ErrConnectionWriteTimeout } - dataCopy := func() { - if len(data) == 0 { - return // An empty data is ignored. - } - - // In the event of session shutdown or connection write timeout, we need to prevent `send` from reading - // the body buffer after returning from this function since the caller may re-use the underlying array. - ready.mu.Lock() - defer ready.mu.Unlock() - - if ready.data == nil { - return // data was already copied in `send`. - } - newData := make([]byte, len(data)) - copy(newData, data) - ready.data = newData - } - select { case err, ok := <-errCh: if !ok { @@ -201,11 +182,9 @@ func (s *Session) waitForSend(data []byte) error { s.logger.Debug("Error sending data", "error", err) return err case <-s.ctx.Done(): - dataCopy() s.logger.Debug("Session shutdown while waiting send error") return ErrSessionShutdown case <-timer.C: - dataCopy() s.logger.Debug("Connection write timeout while waiting send error") return ErrConnectionWriteTimeout } @@ -224,22 +203,16 @@ func (s *Session) sendLoop() error { case packet := <-s.sendCh: packet.mu.Lock() + _, rErr := dataBuf.ReadFrom(packet.r) + if rErr != nil { + packet.mu.Unlock() + s.logger.Error("Failed to copy data into buffer", "error", rErr) + s.asyncSendErr(packet.err, rErr) + return rErr + } if s.logger.Enabled(s.ctx, slog.LevelDebug) { s.logger.Debug("Sending data to connection", - "data", base64.StdEncoding.EncodeToString(packet.data)) - } - if len(packet.data) != 0 { - // Copy the data into the buffer to avoid holding a mutex lock during the writing. - _, err := dataBuf.Write(packet.data) - if err != nil { - packet.data = nil - packet.mu.Unlock() - s.logger.Error("Failed to copy data into buffer", "error", err) - s.asyncSendErr(packet.err, err) - return err // TODO: Do we need to return here? - } - s.logger.Debug("Data copied into buffer") - packet.data = nil + "data", base64.StdEncoding.EncodeToString(dataBuf.Bytes())) } packet.mu.Unlock() @@ -375,7 +348,7 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { s.logger.Debug("Invoking OnReceive handler", "message", base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) } - s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. + s.config.handler.OnReceive(s, bytes.NewReader(s.receiveBuffer.Bytes())) // Invoke OnReceive handler. return nil } @@ -405,9 +378,13 @@ func (s *Session) keepaliveLoop() error { // sendPacket is used to send data. type sendPacket struct { - mu sync.Mutex // Protects data from unsafe reads. - data []byte - err chan<- error + mu sync.Mutex // Protects data from unsafe reads. + r io.Reader + err chan<- error +} + +func newSendPacket(data []byte, ch chan<- error) *sendPacket { + return &sendPacket{r: bytes.NewReader(data), err: ch} } // asyncSendErr is used to try an async send of an error. diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index c39220027..745ee194b 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -1,6 +1,7 @@ package networking_test import ( + "bytes" "context" "encoding/binary" "errors" @@ -55,7 +56,8 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 5, n) }) - sc2 := serverHandler.On("OnReceive", ss, encodeMessage("Hello session")).Once().Return() + sc2 := serverHandler.On("OnReceive", ss, bytes.NewReader(encodeMessage("Hello session"))). + Once().Return() sc2.NotBefore(sc1). Run(func(_ mock.Arguments) { n, wErr := ss.Write(encodeMessage("Hi")) @@ -73,7 +75,7 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 17, n) }) - cl2 := clientHandler.On("OnReceive", cs, encodeMessage("Hi")).Once().Return() + cl2 := clientHandler.On("OnReceive", cs, bytes.NewReader(encodeMessage("Hi"))).Once().Return() cl2.NotBefore(cl1). Run(func(_ mock.Arguments) { cWG.Done()