From 14ae6e5a28dff061a5acd4e8b8a98cbff0e4fd4c Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Sat, 28 Dec 2024 16:24:16 +0400 Subject: [PATCH] New network connection (#1547) * TaskGroup goroutine manager added in pkg/execution package. Tests on TaskGroup added. * Networking package with a new connection handler Session added. * Logger interface removed from networking package. Standard slog package is used instead. * WIP. Simple connection replaced with NetClient. NetClient usage moved into Universal client. Handshake proto updated to compatibility with Handshake interface from networking package. * Fixed NetClient closing issue. Configuration option to set KeepAliveInterval added to networking.Config. * Redundant log removed. * Move save int conversion to safecast lib. * Fix data race error in 'networking_test' package Implement 'io.Stringer' for 'Session' struct. Data race happens because 'clientHandler' mock in 'TestSessionTimeoutOnHandshake' test reads 'Session' structure at the same time as 'clientSession.Close' call. * Replace atomic.Uint32 with atomic.Bool and use CompareAndSwap there it's possible. Replace random delay with constan to make test not blink. Simplify assertion in test to make it stable. * Assertions added. Style fixed. * Simplified closing and close logic in NetClient. Added logs on handshake rejection to clarify the reason of rejections. Added and used function to configure Session with list of Slog attributes. * Prepare for new timer in Go 1.23 Co-authored-by: Nikolay Eskov * Move constant into function were it used. Proper error declaration. * Better way to prevent from running multiple receiveLoops. Shutdown lock replaced with sync.Once. * Better data emptyness checks. * Better read error handling. Co-authored-by: Nikolay Eskov * Use constructor. * Wrap heavy logging into log level checks. Fix data lock and data access order. * Session configuration accepts slog handler to set up logging. Discarding slog handler implemented and used instead of setting default slog logger. Checks on interval values added to Session constructor. * Close error channel on sending data successfully. Better error channel passing. Reset receiving buffer by deffering. * Better error handling while reading. Co-authored-by: Nikolay Eskov * Fine error assertions. * Fix blinking test. * Better configuration handling. Co-authored-by: Nikolay Eskov * Fixed blinking test TestCloseParentContext. Wait group added to wait for client to finish sending handshake. Better wait groups naming. * Better test workflow. Better wait group naming. * Fix deadlock in test by introducing wait group instead of sleep. * Internal sendPacket reimplemented using io.Reader. Data restoration function removed. Handler's OnReceive use io.Reader to pass received data. Tests updated. Mocks regenerated. * Itest network client handler updated. * Changed the way OnReceive passes the receiveBuffer. Test updated. --------- Co-authored-by: Nikolay Eskov --- .mockery.yaml | 12 + go.mod | 4 + go.sum | 6 + itests/clients/grpc_client.go | 5 + itests/clients/net_client.go | 218 +++++++++++++ itests/clients/node_client.go | 41 ++- itests/clients/universal_client.go | 9 +- itests/fixtures/base_fixtures.go | 4 +- itests/net/connection.go | 148 --------- itests/utilities/common.go | 7 +- pkg/execution/taskgroup.go | 118 +++++++ pkg/execution/taskgroup_test.go | 175 ++++++++++ pkg/networking/address.go | 23 ++ pkg/networking/configuration.go | 71 ++++ pkg/networking/handler.go | 15 + pkg/networking/logging.go | 18 + pkg/networking/mocks/handler.go | 138 ++++++++ pkg/networking/mocks/header.go | 238 ++++++++++++++ pkg/networking/mocks/protocol.go | 278 ++++++++++++++++ pkg/networking/network.go | 59 ++++ pkg/networking/protocol.go | 38 +++ pkg/networking/session.go | 407 +++++++++++++++++++++++ pkg/networking/session_test.go | 507 +++++++++++++++++++++++++++++ pkg/networking/timers.go | 41 +++ pkg/p2p/conn/conn.go | 6 +- pkg/proto/microblock.go | 21 +- pkg/proto/proto.go | 126 +++---- pkg/ride/math/math_test.go | 1 + pkg/util/common/util.go | 9 + 29 files changed, 2508 insertions(+), 235 deletions(-) create mode 100644 .mockery.yaml create mode 100644 itests/clients/net_client.go delete mode 100644 itests/net/connection.go create mode 100644 pkg/execution/taskgroup.go create mode 100644 pkg/execution/taskgroup_test.go create mode 100644 pkg/networking/address.go create mode 100644 pkg/networking/configuration.go create mode 100644 pkg/networking/handler.go create mode 100644 pkg/networking/logging.go create mode 100644 pkg/networking/mocks/handler.go create mode 100644 pkg/networking/mocks/header.go create mode 100644 pkg/networking/mocks/protocol.go create mode 100644 pkg/networking/network.go create mode 100644 pkg/networking/protocol.go create mode 100644 pkg/networking/session.go create mode 100644 pkg/networking/session_test.go create mode 100644 pkg/networking/timers.go diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 000000000..d7430c317 --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,12 @@ +quiet: False +with-expecter: True +dir: "{{.InterfaceDir}}/mocks" +mockname: "Mock{{.InterfaceName}}" +filename: "{{.InterfaceName | snakecase}}.go" + +packages: + github.com/wavesplatform/gowaves/pkg/networking: + interfaces: + Header: + Protocol: + Handler: diff --git a/go.mod b/go.mod index 30383602d..57a24d401 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( filippo.io/edwards25519 v1.1.0 github.com/beevik/ntp v1.4.3 github.com/btcsuite/btcd/btcec/v2 v2.3.4 + github.com/ccoveille/go-safecast v1.2.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/consensys/gnark v0.11.0 @@ -22,6 +23,7 @@ require ( github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab github.com/jinzhu/copier v0.4.0 github.com/mr-tron/base58 v1.2.0 + github.com/neilotoole/slogt v1.1.0 github.com/ory/dockertest/v3 v3.11.0 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 @@ -42,6 +44,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 github.com/xenolf/lego v2.7.2+incompatible go.uber.org/atomic v1.11.0 + go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e @@ -98,6 +101,7 @@ require ( github.com/rs/zerolog v1.33.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/gjson v1.14.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/go.sum b/go.sum index 9916f72cd..462d7e1ce 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurT github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ= github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= +github.com/ccoveille/go-safecast v1.2.0 h1:H4X7aosepsU1Mfk+098CTdKpsDH0cfYJ2RmwXFjgvfc= +github.com/ccoveille/go-safecast v1.2.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -199,6 +201,8 @@ github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjW github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= +github.com/neilotoole/slogt v1.1.0/go.mod h1:RCrGXkPc/hYybNulqQrMHRtvlQ7F6NktNVLuLwk6V+w= github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -282,6 +286,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/itests/clients/grpc_client.go b/itests/clients/grpc_client.go index bb68da291..2cf1eadbb 100644 --- a/itests/clients/grpc_client.go +++ b/itests/clients/grpc_client.go @@ -94,6 +94,11 @@ func (c *GRPCClient) GetAssetsInfo(t *testing.T, id []byte) *g.AssetInfoResponse return assetInfo } +func (c *GRPCClient) Close(t testing.TB) { + err := c.conn.Close() + assert.NoError(t, err, "failed to close GRPC connection to %s node", c.impl.String()) +} + func (c *GRPCClient) getBalance(t *testing.T, req *g.BalancesRequest) *g.BalanceResponse { ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go new file mode 100644 index 000000000..e6d984adb --- /dev/null +++ b/itests/clients/net_client.go @@ -0,0 +1,218 @@ +package clients + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "log/slog" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/neilotoole/slogt" + "github.com/stretchr/testify/require" + + "github.com/wavesplatform/gowaves/itests/config" + "github.com/wavesplatform/gowaves/pkg/networking" + "github.com/wavesplatform/gowaves/pkg/proto" +) + +const ( + appName = "wavesL" + nonce = uint64(0) + networkTimeout = 3 * time.Second + pingInterval = 5 * time.Second +) + +type NetClient struct { + ctx context.Context + t testing.TB + impl Implementation + n *networking.Network + c *networking.Config + s *networking.Session + + closing atomic.Bool + closed sync.Once +} + +func NewNetClient( + ctx context.Context, t testing.TB, impl Implementation, port string, peers []proto.PeerInfo, +) *NetClient { + n := networking.NewNetwork() + p := newProtocol(t, nil) + h := newHandler(t, peers) + log := slogt.New(t) + conf := networking.NewConfig(p, h). + WithSlogHandler(log.Handler()). + WithWriteTimeout(networkTimeout). + WithKeepAliveInterval(pingInterval). + WithSlogAttributes(slog.String("suite", t.Name()), slog.String("impl", impl.String())) + + conn, err := net.Dial("tcp", config.DefaultIP+":"+port) + require.NoError(t, err, "failed to dial TCP to %s node", impl.String()) + + s, err := n.NewSession(ctx, conn, conf) + require.NoError(t, err, "failed to establish new session to %s node", impl.String()) + + cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, s: s} + h.client = cli // Set client reference in handler. + return cli +} + +func (c *NetClient) SendHandshake() { + handshake := &proto.Handshake{ + AppName: appName, + Version: proto.ProtocolVersion(), + NodeName: "itest", + NodeNonce: nonce, + DeclaredAddr: proto.HandshakeTCPAddr{}, + Timestamp: proto.NewTimestampFromTime(time.Now()), + } + buf := bytes.NewBuffer(nil) + _, err := handshake.WriteTo(buf) + require.NoError(c.t, err, + "failed to marshal handshake to %s node at %q", c.impl.String(), c.s.RemoteAddr()) + _, err = c.s.Write(buf.Bytes()) + require.NoError(c.t, err, + "failed to send handshake to %s node at %q", c.impl.String(), c.s.RemoteAddr()) +} + +func (c *NetClient) SendMessage(m proto.Message) { + _, err := m.WriteTo(c.s) + require.NoError(c.t, err, "failed to send message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) +} + +func (c *NetClient) Close() { + c.closed.Do(func() { + if c.closing.CompareAndSwap(false, true) { + c.t.Logf("Closing connection to %s node at %q", c.impl.String(), c.s.RemoteAddr().String()) + } + err := c.s.Close() + require.NoError(c.t, err, "failed to close session to %s node at %q", c.impl.String(), c.s.RemoteAddr()) + }) +} + +func (c *NetClient) reconnect() { + c.t.Logf("Reconnecting to %q", c.s.RemoteAddr().String()) + conn, err := net.Dial("tcp", c.s.RemoteAddr().String()) + require.NoError(c.t, err, "failed to dial TCP to %s node", c.impl.String()) + + s, err := c.n.NewSession(c.ctx, conn, c.c) + require.NoError(c.t, err, "failed to re-establish the session to %s node", c.impl.String()) + c.s = s + + c.SendHandshake() +} + +type protocol struct { + t testing.TB + dropLock sync.Mutex + drop map[proto.PeerMessageID]struct{} +} + +func newProtocol(t testing.TB, drop []proto.PeerMessageID) *protocol { + m := make(map[proto.PeerMessageID]struct{}) + for _, id := range drop { + m[id] = struct{}{} + } + return &protocol{t: t, drop: m} +} + +func (p *protocol) EmptyHandshake() networking.Handshake { + return &proto.Handshake{} +} + +func (p *protocol) EmptyHeader() networking.Header { + return &proto.Header{} +} + +func (p *protocol) Ping() ([]byte, error) { + msg := &proto.GetPeersMessage{} + return msg.MarshalBinary() +} + +func (p *protocol) IsAcceptableHandshake(h networking.Handshake) bool { + hs, ok := h.(*proto.Handshake) + if !ok { + return false + } + // Reject nodes with incorrect network bytes, unsupported protocol versions, + // or a zero nonce (indicating a self-connection). + if hs.AppName != appName || hs.Version.Cmp(proto.ProtocolVersion()) < 0 || hs.NodeNonce == 0 { + p.t.Logf("Unacceptable handshake:") + if hs.AppName != appName { + p.t.Logf("\tinvalid application name %q, expected %q", hs.AppName, appName) + } + if hs.Version.Cmp(proto.ProtocolVersion()) < 0 { + p.t.Logf("\tinvalid application version %q should be equal or more than %q", + hs.Version, proto.ProtocolVersion()) + } + if hs.NodeNonce == 0 { + p.t.Logf("\tinvalid node nonce %d", hs.NodeNonce) + } + return false + } + return true +} + +func (p *protocol) IsAcceptableMessage(h networking.Header) bool { + hdr, ok := h.(*proto.Header) + if !ok { + return false + } + p.dropLock.Lock() + defer p.dropLock.Unlock() + _, ok = p.drop[hdr.ContentID] + return !ok +} + +type handler struct { + peers []proto.PeerInfo + t testing.TB + client *NetClient +} + +func newHandler(t testing.TB, peers []proto.PeerInfo) *handler { + return &handler{t: t, peers: peers} +} + +func (h *handler) OnReceive(s *networking.Session, r io.Reader) { + data, err := io.ReadAll(r) + if err != nil { + h.t.Logf("Failed to read message from %q: %v", s.RemoteAddr(), err) + h.t.FailNow() + return + } + msg, err := proto.UnmarshalMessage(data) + if err != nil { // Fail test on unmarshal error. + h.t.Logf("Failed to unmarshal message from bytes: %q", base64.StdEncoding.EncodeToString(data)) + h.t.FailNow() + return + } + switch msg.(type) { // Only reply with peers on GetPeersMessage. + case *proto.GetPeersMessage: + h.t.Logf("Received GetPeersMessage from %q", s.RemoteAddr()) + rpl := &proto.PeersMessage{Peers: h.peers} + if _, sErr := rpl.WriteTo(s); sErr != nil { + h.t.Logf("Failed to send peers message: %v", sErr) + h.t.FailNow() + return + } + default: + } +} + +func (h *handler) OnHandshake(_ *networking.Session, _ networking.Handshake) { + h.t.Logf("Connection to %s node at %q was established", h.client.impl.String(), h.client.s.RemoteAddr()) +} + +func (h *handler) OnClose(s *networking.Session) { + h.t.Logf("Connection to %q was closed", s.RemoteAddr()) + if !h.client.closing.Load() && h.client != nil { + h.client.reconnect() + } +} diff --git a/itests/clients/node_client.go b/itests/clients/node_client.go index 7cf28b9c8..95a54d046 100644 --- a/itests/clients/node_client.go +++ b/itests/clients/node_client.go @@ -10,8 +10,10 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/wavesplatform/gowaves/itests/config" d "github.com/wavesplatform/gowaves/itests/docker" "github.com/wavesplatform/gowaves/pkg/crypto" "github.com/wavesplatform/gowaves/pkg/proto" @@ -24,10 +26,19 @@ type NodesClients struct { ScalaClient *NodeUniversalClient } -func NewNodesClients(t *testing.T, goPorts, scalaPorts *d.PortConfig) *NodesClients { +func NewNodesClients(ctx context.Context, t *testing.T, goPorts, scalaPorts *d.PortConfig) *NodesClients { + sp, err := proto.NewPeerInfoFromString(config.DefaultIP + ":" + scalaPorts.BindPort) + require.NoError(t, err, "failed to create Scala peer info") + gp, err := proto.NewPeerInfoFromString(config.DefaultIP + ":" + goPorts.BindPort) + require.NoError(t, err, "failed to create Go peer info") + peers := []proto.PeerInfo{sp, gp} return &NodesClients{ - GoClient: NewNodeUniversalClient(t, NodeGo, goPorts.RESTAPIPort, goPorts.GRPCPort), - ScalaClient: NewNodeUniversalClient(t, NodeScala, scalaPorts.RESTAPIPort, scalaPorts.GRPCPort), + GoClient: NewNodeUniversalClient( + ctx, t, NodeGo, goPorts.RESTAPIPort, goPorts.GRPCPort, goPorts.BindPort, peers, + ), + ScalaClient: NewNodeUniversalClient( + ctx, t, NodeScala, scalaPorts.RESTAPIPort, scalaPorts.GRPCPort, scalaPorts.BindPort, peers, + ), } } @@ -236,7 +247,6 @@ func (c *NodesClients) SynchronizedWavesBalances( if err != nil { t.Logf("Errors while requesting balances: %v", err) } - t.Log("Entering loop") for { commonHeight := mostCommonHeight(sbs) toRetry := make([]proto.WavesAddress, 0, len(addresses)) @@ -273,6 +283,29 @@ func (c *NodesClients) SynchronizedWavesBalances( return r } +func (c *NodesClients) Handshake() { + c.GoClient.Connection.SendHandshake() + c.ScalaClient.Connection.SendHandshake() +} + +func (c *NodesClients) SendToNodes(t *testing.T, m proto.Message, scala bool) { + t.Logf("Sending message to Go node: %T", m) + c.GoClient.Connection.SendMessage(m) + t.Log("Message sent to Go node") + if scala { + t.Logf("Sending message to Scala node: %T", m) + c.ScalaClient.Connection.SendMessage(m) + t.Log("Message sent to Scala node") + } +} + +func (c *NodesClients) Close(t *testing.T) { + c.GoClient.GRPCClient.Close(t) + c.GoClient.Connection.Close() + c.ScalaClient.GRPCClient.Close(t) + c.ScalaClient.Connection.Close() +} + func (c *NodesClients) requestNodesAvailableBalances( ctx context.Context, address proto.WavesAddress, ) (addressedBalanceAtHeight, error) { diff --git a/itests/clients/universal_client.go b/itests/clients/universal_client.go index 9911e2015..32c20833a 100644 --- a/itests/clients/universal_client.go +++ b/itests/clients/universal_client.go @@ -1,19 +1,26 @@ package clients import ( + "context" "testing" + + "github.com/wavesplatform/gowaves/pkg/proto" ) type NodeUniversalClient struct { Implementation Implementation HTTPClient *HTTPClient GRPCClient *GRPCClient + Connection *NetClient } -func NewNodeUniversalClient(t *testing.T, impl Implementation, httpPort string, grpcPort string) *NodeUniversalClient { +func NewNodeUniversalClient( + ctx context.Context, t *testing.T, impl Implementation, httpPort, grpcPort, netPort string, peers []proto.PeerInfo, +) *NodeUniversalClient { return &NodeUniversalClient{ Implementation: impl, HTTPClient: NewHTTPClient(t, impl, httpPort), GRPCClient: NewGRPCClient(t, impl, grpcPort), + Connection: NewNetClient(ctx, t, impl, netPort, peers), } } diff --git a/itests/fixtures/base_fixtures.go b/itests/fixtures/base_fixtures.go index 23f359066..eb8f9a911 100644 --- a/itests/fixtures/base_fixtures.go +++ b/itests/fixtures/base_fixtures.go @@ -49,7 +49,8 @@ func (suite *BaseSuite) BaseSetup(options ...config.BlockchainOption) { suite.Require().NoError(ssErr, "couldn't start Scala node container") } - suite.Clients = clients.NewNodesClients(suite.T(), docker.GoNode().Ports(), docker.ScalaNode().Ports()) + suite.Clients = clients.NewNodesClients(suite.MainCtx, suite.T(), docker.GoNode().Ports(), docker.ScalaNode().Ports()) + suite.Clients.Handshake() } func (suite *BaseSuite) SetupSuite() { @@ -58,6 +59,7 @@ func (suite *BaseSuite) SetupSuite() { func (suite *BaseSuite) TearDownSuite() { suite.Clients.WaitForStateHashEquality(suite.T()) + suite.Clients.Close(suite.T()) suite.Docker.Finish(suite.Cancel) } diff --git a/itests/net/connection.go b/itests/net/connection.go deleted file mode 100644 index 7fafe65f1..000000000 --- a/itests/net/connection.go +++ /dev/null @@ -1,148 +0,0 @@ -package net - -import ( - "bufio" - stderrs "errors" - "net" - "testing" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - - "github.com/wavesplatform/gowaves/itests/config" - d "github.com/wavesplatform/gowaves/itests/docker" - "github.com/wavesplatform/gowaves/pkg/proto" -) - -type OutgoingPeer struct { - conn net.Conn -} - -func NewConnection(declAddr proto.TCPAddr, address string, ver proto.Version, wavesNetwork string) (op *OutgoingPeer, err error) { - c, err := net.Dial("tcp", address) - if err != nil { - return nil, errors.Wrapf(err, "failed to connect to %s", address) - } - defer func() { - if err != nil { - if closeErr := c.Close(); closeErr != nil { - err = errors.Wrap(err, closeErr.Error()) - } - } - }() - handshake := proto.Handshake{ - AppName: wavesNetwork, - Version: ver, - NodeName: "itest", - NodeNonce: 0x0, - DeclaredAddr: proto.HandshakeTCPAddr(declAddr), - Timestamp: proto.NewTimestampFromTime(time.Now()), - } - - _, err = handshake.WriteTo(c) - if err != nil { - return nil, errors.Wrapf(err, "failed to send handshake to %s", address) - } - - _, err = handshake.ReadFrom(bufio.NewReader(c)) - if err != nil { - return nil, errors.Wrapf(err, "failed to read handshake from %s", address) - } - - return &OutgoingPeer{conn: c}, nil -} - -func (a *OutgoingPeer) SendMessage(m proto.Message) error { - b, err := m.MarshalBinary() - if err != nil { - return err - } - - _, err = a.conn.Write(b) - if err != nil { - return errors.Wrapf(err, "failed to send message") - } - return nil -} - -func (a *OutgoingPeer) Close() error { - return a.conn.Close() -} - -type NodeConnections struct { - scalaCon *OutgoingPeer - goCon *OutgoingPeer -} - -func NewNodeConnections(goPorts, scalaPorts *d.PortConfig) (NodeConnections, error) { - var connections NodeConnections - err := retry(1*time.Second, func() error { - var err error - connections, err = establishConnections(goPorts, scalaPorts) - return err - }) - return connections, err -} - -func establishConnections(goPorts, scalaPorts *d.PortConfig) (NodeConnections, error) { - goCon, err := NewConnection( - proto.TCPAddr{}, - config.DefaultIP+":"+goPorts.BindPort, - proto.ProtocolVersion(), "wavesL", - ) - if err != nil { - return NodeConnections{}, errors.Wrap(err, "failed to create connection to go node") - } - scalaCon, err := NewConnection( - proto.TCPAddr{}, - config.DefaultIP+":"+scalaPorts.BindPort, - proto.ProtocolVersion(), "wavesL", - ) - if err != nil { - if closeErr := goCon.Close(); closeErr != nil { - return NodeConnections{}, errors.Wrap(stderrs.Join(closeErr, err), - "failed to create connection to scala node and close go node connection") - } - return NodeConnections{}, errors.Wrap(err, "failed to create connection to scala node") - } - return NodeConnections{scalaCon: scalaCon, goCon: goCon}, nil -} - -func retry(timeout time.Duration, f func() error) error { - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = 100 * time.Millisecond - bo.MaxInterval = 500 * time.Millisecond - bo.MaxElapsedTime = timeout - if err := backoff.Retry(f, bo); err != nil { - if bo.NextBackOff() == backoff.Stop { - return errors.Wrap(err, "reached retry deadline") - } - return err - } - return nil -} - -func (c *NodeConnections) SendToNodes(t *testing.T, m proto.Message, scala bool) { - t.Logf("Sending message to go node: %T", m) - err := c.goCon.SendMessage(m) - assert.NoError(t, err, "failed to send TransactionMessage to go node") - t.Log("Message sent to go node") - if scala { - t.Logf("Sending message to scala node: %T", m) - err = c.scalaCon.SendMessage(m) - assert.NoError(t, err, "failed to send TransactionMessage to scala node") - t.Log("Message sent to scala node") - } -} - -func (c *NodeConnections) Close(t *testing.T) { - t.Log("Closing connections") - err := c.goCon.Close() - assert.NoError(t, err, "failed to close go node connection") - - err = c.scalaCon.Close() - assert.NoError(t, err, "failed to close scala node connection") - t.Log("Connections closed") -} diff --git a/itests/utilities/common.go b/itests/utilities/common.go index 9c7a90a54..960823072 100644 --- a/itests/utilities/common.go +++ b/itests/utilities/common.go @@ -23,7 +23,6 @@ import ( "github.com/wavesplatform/gowaves/itests/config" f "github.com/wavesplatform/gowaves/itests/fixtures" - "github.com/wavesplatform/gowaves/itests/net" "github.com/wavesplatform/gowaves/pkg/client" "github.com/wavesplatform/gowaves/pkg/crypto" g "github.com/wavesplatform/gowaves/pkg/grpc/generated/waves/node/grpc" @@ -671,11 +670,7 @@ func SendAndWaitTransaction(suite *f.BaseSuite, tx proto.Transaction, scheme pro } scala := !waitForTx - connections, err := net.NewNodeConnections(suite.Docker.GoNode().Ports(), suite.Docker.ScalaNode().Ports()) - suite.Require().NoError(err, "failed to create new node connections") - defer connections.Close(suite.T()) - - connections.SendToNodes(suite.T(), txMsg, scala) + suite.Clients.SendToNodes(suite.T(), txMsg, scala) suite.T().Log("Tx msg was successfully send to nodes") suite.T().Log("Waiting for Tx appears in Blockchain") diff --git a/pkg/execution/taskgroup.go b/pkg/execution/taskgroup.go new file mode 100644 index 000000000..6550a257f --- /dev/null +++ b/pkg/execution/taskgroup.go @@ -0,0 +1,118 @@ +package execution + +import ( + "sync" + "sync/atomic" +) + +// A TaskGroup manages a collection of cooperating goroutines. Add new tasks to the group with the Run method. +// Call the Wait method to wait for the tasks to complete. +// A zero value is ready for use, but must not be copied after its first use. +// +// The group collects any errors returned by the tasks in the group. +// The first non-nil error reported by any execution and not filtered is returned from the Wait method. +type TaskGroup struct { + wg sync.WaitGroup // Counter for active goroutines. + + // active is true when the group is "active", meaning there has been at least one call to Run since the group + // was created or the last Wait. + // + // Together active and errLock work as a kind of resettable sync.Once. The fast path reads active and only + // acquires errLock if it discovers setup is needed. + active atomic.Bool + + errLock sync.Mutex // Guards the fields below. + err error // First captured error returned from Wait. + onError errorFunc // Called each time a task returns non-nil error. +} + +// NewTaskGroup constructs a new empty group with the specified error handler. +// See [TaskGroup.OnError] for a description of how errors are filtered. If handler is nil, no filtering is performed. +// Main properties of the TaskGroup are: +// - Cancel propagation. +// - Error propagation. +// - Waiting for all tasks to finish. +func NewTaskGroup(handler func(error) error) *TaskGroup { + return new(TaskGroup).OnError(handler) +} + +// OnError sets the error handler for TaskGroup. If handler is nil, +// the error handler is removed and errors are no longer filtered. Otherwise, each non-nil error reported by an +// execution running in g is passed to handler. +// +// Then handler is called with each reported error, and its result replaces the reported value. This permits handler to +// suppress or replace the error value selectively. +// +// Calls to handler are synchronized so that it is safe for handler to manipulate local data structures without +// additional locking. It is safe to call OnError while tasks are active in TaskGroup. +func (g *TaskGroup) OnError(handler func(error) error) *TaskGroup { + g.errLock.Lock() + defer g.errLock.Unlock() + g.onError = handler + return g +} + +// Run starts an [execute] function in a new goroutine in [TaskGroup]. The execution is not interrupted by TaskGroup, +// so the [execute] function should include the interruption logic. +func (g *TaskGroup) Run(execute func() error) { + g.wg.Add(1) + if !g.active.Load() { + g.activate() + } + go func() { + defer g.wg.Done() + if err := execute(); err != nil { + g.handleError(err) + } + }() +} + +// Wait blocks until all the goroutines currently active in the TaskGroup have returned, and all reported errors have +// been delivered to the handler. It returns the first non-nil error reported by any of the goroutines in the group and +// not filtered by an OnError handler. +// +// As with sync.WaitGroup, new tasks can be added to TaskGroup during a Wait call only if the TaskGroup contains at +// least one active execution when Wait is called and continuously thereafter until the last concurrent call to +// Run returns. +// +// Wait may be called from at most one goroutine at a time. After Wait has returned, the group is ready for reuse. +func (g *TaskGroup) Wait() error { + g.wg.Wait() + g.errLock.Lock() + defer g.errLock.Unlock() + + // If the group is still active, deactivate it now. + g.active.CompareAndSwap(true, false) + return g.err +} + +// activate resets the state of the group and marks it as "active". This is triggered by adding a goroutine to +// an empty group. +func (g *TaskGroup) activate() { + g.errLock.Lock() + defer g.errLock.Unlock() + if g.active.CompareAndSwap(false, true) { + g.err = nil + } +} + +// handleError synchronizes access to the error handler and captures the first non-nil error. +func (g *TaskGroup) handleError(err error) { + g.errLock.Lock() + defer g.errLock.Unlock() + e := g.onError.filter(err) + if e != nil && g.err == nil { + g.err = e // Capture the first unfiltered error. + } +} + +// An errorFunc is called by a group each time an execution reports an error. Its return value replaces the reported +// error, so the errorFunc can filter or suppress errors by modifying or discarding the input error. +type errorFunc func(error) error + +func (ef errorFunc) filter(err error) error { + if ef == nil { + return err + } + return ef(err) +} diff --git a/pkg/execution/taskgroup_test.go b/pkg/execution/taskgroup_test.go new file mode 100644 index 000000000..2678a0a04 --- /dev/null +++ b/pkg/execution/taskgroup_test.go @@ -0,0 +1,175 @@ +package execution_test + +import ( + "context" + "errors" + "math/rand/v2" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wavesplatform/gowaves/pkg/execution" +) + +func TestBasic(t *testing.T) { + defer goleak.VerifyNone(t) + + // Verify that the group works at all. + var g execution.TaskGroup + g.Run(work(25, nil)) + err := g.Wait() + require.NoError(t, err) + + // Verify that the group can be reused. + g.Run(work(50, nil)) + g.Run(work(75, nil)) + err = g.Wait() + require.NoError(t, err) + + // Verify that error is propagated without an error handler. + g.Run(work(50, errors.New("expected error"))) + err = g.Wait() + require.Error(t, err) +} + +func TestErrorsPropagation(t *testing.T) { + defer goleak.VerifyNone(t) + + expected := errors.New("expected error") + + var g execution.TaskGroup + g.Run(func() error { return expected }) + err := g.Wait() + require.ErrorIs(t, err, expected) + + g.OnError(func(error) error { return nil }) // discard all error + g.Run(func() error { return expected }) + err = g.Wait() + require.NoError(t, err) +} + +func TestCancelPropagation(t *testing.T) { + defer goleak.VerifyNone(t) + + const numTasks = 64 + + var errs []error + g := execution.NewTaskGroup(func(err error) error { + errs = append(errs, err) // Only collect non-nil errors and suppress them. + return nil + }) + + errOther := errors.New("something is wrong") + ctx, cancel := context.WithCancel(context.Background()) + var numOK int32 + for range numTasks { + g.Run(func() error { + d1 := randomDuration(2) + d2 := randomDuration(2) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d1): + return errOther + case <-time.After(d2): + atomic.AddInt32(&numOK, 1) // Count successful executions. + return nil + } + }) + } + cancel() + + err := g.Wait() + require.NoError(t, err) // No captured error is expected, should be suppressed. + + var numCanceled, numOther int + for _, e := range errs { + switch { + case errors.Is(e, context.Canceled): + numCanceled++ + case errors.Is(e, errOther): + numOther++ + default: + require.FailNowf(t, "No error is expected", "unexpected error: %v", e) + } + } + + total := int(numOK) + numCanceled + numOther + assert.Equal(t, numTasks, total) +} + +func TestWaitingForFinish(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel := context.WithCancel(context.Background()) + + failure := errors.New("failure") + exec := func() error { + select { + case <-ctx.Done(): + return work(50, nil)() + case <-time.After(60 * time.Millisecond): + return failure + } + } + + var g execution.TaskGroup + g.Run(exec) + g.Run(exec) + g.Run(exec) + + cancel() + + err := g.Wait() + require.NoError(t, err) +} + +func TestRegression(t *testing.T) { + defer goleak.VerifyNone(t) + + t.Run("WaitRace", func(_ *testing.T) { + ready := make(chan struct{}) + var g execution.TaskGroup + g.Run(func() error { + <-ready + return nil + }) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + err := g.Wait() + require.NoError(t, err) + }() + go func() { + defer wg.Done() + err := g.Wait() + require.NoError(t, err) + }() + + close(ready) + wg.Wait() + }) + t.Run("WaitUnstarted", func(t *testing.T) { + require.NotPanics(t, func() { + var g execution.TaskGroup + err := g.Wait() + require.NoError(t, err) + }) + }) +} + +func randomDuration(n int64) time.Duration { + return time.Duration(rand.Int64N(n)) * time.Millisecond +} + +// work returns an execution function that does nothing for random number of ms with [n] ms upper limit and returns err. +func work(n int64, err error) func() error { + return func() error { time.Sleep(randomDuration(n)); return err } +} diff --git a/pkg/networking/address.go b/pkg/networking/address.go new file mode 100644 index 000000000..c1ebf9ca3 --- /dev/null +++ b/pkg/networking/address.go @@ -0,0 +1,23 @@ +package networking + +import ( + "fmt" + "net" +) + +type addressable interface { + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +type sessionAddress struct { + addr string +} + +func (*sessionAddress) Network() string { + return "session" +} + +func (a *sessionAddress) String() string { + return fmt.Sprintf("session:%s", a.addr) +} diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go new file mode 100644 index 000000000..151bc8541 --- /dev/null +++ b/pkg/networking/configuration.go @@ -0,0 +1,71 @@ +package networking + +import ( + "log/slog" + "time" +) + +const ( + defaultKeepAliveInterval = 1 * time.Minute + defaultConnectionWriteTimeout = 15 * time.Second +) + +// Config allows to set some parameters of the [Conn] or it's underlying connection. +type Config struct { + slogHandler slog.Handler + protocol Protocol + handler Handler + keepAlive bool + keepAliveInterval time.Duration + connectionWriteTimeout time.Duration + attributes []any +} + +// NewConfig creates a new Config and sets required Protocol and Handler parameters. +// Other parameters are set to their default values. +func NewConfig(p Protocol, h Handler) *Config { + return &Config{ + protocol: p, + handler: h, + keepAlive: true, + keepAliveInterval: defaultKeepAliveInterval, + connectionWriteTimeout: defaultConnectionWriteTimeout, + attributes: nil, + } +} + +// WithSlogHandler sets the slog handler. +func (c *Config) WithSlogHandler(handler slog.Handler) *Config { + c.slogHandler = handler + return c +} + +// WithWriteTimeout sets connection write timeout attribute to the Config. +func (c *Config) WithWriteTimeout(timeout time.Duration) *Config { + c.connectionWriteTimeout = timeout + return c +} + +// WithSlogAttribute adds an attribute to the slice of attributes. +func (c *Config) WithSlogAttribute(attr slog.Attr) *Config { + c.attributes = append(c.attributes, attr) + return c +} + +// WithSlogAttributes adds given attributes to the slice of attributes. +func (c *Config) WithSlogAttributes(attrs ...slog.Attr) *Config { + for _, attr := range attrs { + c.attributes = append(c.attributes, attr) + } + return c +} + +func (c *Config) WithKeepAliveDisabled() *Config { + c.keepAlive = false + return c +} + +func (c *Config) WithKeepAliveInterval(interval time.Duration) *Config { + c.keepAliveInterval = interval + return c +} diff --git a/pkg/networking/handler.go b/pkg/networking/handler.go new file mode 100644 index 000000000..81b81cb39 --- /dev/null +++ b/pkg/networking/handler.go @@ -0,0 +1,15 @@ +package networking + +import "io" + +// Handler is an interface for handling new messages, handshakes and session close events. +type Handler interface { + // OnReceive fired on new message received. + OnReceive(*Session, io.Reader) + + // OnHandshake fired on new Handshake received. + OnHandshake(*Session, Handshake) + + // OnClose fired on Session closed. + OnClose(*Session) +} diff --git a/pkg/networking/logging.go b/pkg/networking/logging.go new file mode 100644 index 000000000..94338ab31 --- /dev/null +++ b/pkg/networking/logging.go @@ -0,0 +1,18 @@ +package networking + +import ( + "context" + "log/slog" +) + +// TODO: Remove this file and the handler when the default [slog.DiscardHandler] will be introduced in +// Go version 1.24. See https://go-review.googlesource.com/c/go/+/626486. + +// discardingHandler is a logger that discards all log messages. +// It is used when no slog handler is provided in the [Config]. +type discardingHandler struct{} + +func (h discardingHandler) Enabled(context.Context, slog.Level) bool { return false } +func (h discardingHandler) Handle(context.Context, slog.Record) error { return nil } +func (h discardingHandler) WithAttrs([]slog.Attr) slog.Handler { return h } +func (h discardingHandler) WithGroup(string) slog.Handler { return h } diff --git a/pkg/networking/mocks/handler.go b/pkg/networking/mocks/handler.go new file mode 100644 index 000000000..a11fdc547 --- /dev/null +++ b/pkg/networking/mocks/handler.go @@ -0,0 +1,138 @@ +// Code generated by mockery v2.50.1. DO NOT EDIT. + +package networking + +import ( + io "io" + + mock "github.com/stretchr/testify/mock" + networking "github.com/wavesplatform/gowaves/pkg/networking" +) + +// MockHandler is an autogenerated mock type for the Handler type +type MockHandler struct { + mock.Mock +} + +type MockHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockHandler) EXPECT() *MockHandler_Expecter { + return &MockHandler_Expecter{mock: &_m.Mock} +} + +// OnClose provides a mock function with given fields: _a0 +func (_m *MockHandler) OnClose(_a0 *networking.Session) { + _m.Called(_a0) +} + +// MockHandler_OnClose_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnClose' +type MockHandler_OnClose_Call struct { + *mock.Call +} + +// OnClose is a helper method to define mock.On call +// - _a0 *networking.Session +func (_e *MockHandler_Expecter) OnClose(_a0 interface{}) *MockHandler_OnClose_Call { + return &MockHandler_OnClose_Call{Call: _e.mock.On("OnClose", _a0)} +} + +func (_c *MockHandler_OnClose_Call) Run(run func(_a0 *networking.Session)) *MockHandler_OnClose_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session)) + }) + return _c +} + +func (_c *MockHandler_OnClose_Call) Return() *MockHandler_OnClose_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnClose_Call) RunAndReturn(run func(*networking.Session)) *MockHandler_OnClose_Call { + _c.Run(run) + return _c +} + +// OnHandshake provides a mock function with given fields: _a0, _a1 +func (_m *MockHandler) OnHandshake(_a0 *networking.Session, _a1 networking.Handshake) { + _m.Called(_a0, _a1) +} + +// MockHandler_OnHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnHandshake' +type MockHandler_OnHandshake_Call struct { + *mock.Call +} + +// OnHandshake is a helper method to define mock.On call +// - _a0 *networking.Session +// - _a1 networking.Handshake +func (_e *MockHandler_Expecter) OnHandshake(_a0 interface{}, _a1 interface{}) *MockHandler_OnHandshake_Call { + return &MockHandler_OnHandshake_Call{Call: _e.mock.On("OnHandshake", _a0, _a1)} +} + +func (_c *MockHandler_OnHandshake_Call) Run(run func(_a0 *networking.Session, _a1 networking.Handshake)) *MockHandler_OnHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session), args[1].(networking.Handshake)) + }) + return _c +} + +func (_c *MockHandler_OnHandshake_Call) Return() *MockHandler_OnHandshake_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnHandshake_Call) RunAndReturn(run func(*networking.Session, networking.Handshake)) *MockHandler_OnHandshake_Call { + _c.Run(run) + return _c +} + +// OnReceive provides a mock function with given fields: _a0, _a1 +func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 io.Reader) { + _m.Called(_a0, _a1) +} + +// MockHandler_OnReceive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnReceive' +type MockHandler_OnReceive_Call struct { + *mock.Call +} + +// OnReceive is a helper method to define mock.On call +// - _a0 *networking.Session +// - _a1 io.Reader +func (_e *MockHandler_Expecter) OnReceive(_a0 interface{}, _a1 interface{}) *MockHandler_OnReceive_Call { + return &MockHandler_OnReceive_Call{Call: _e.mock.On("OnReceive", _a0, _a1)} +} + +func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 io.Reader)) *MockHandler_OnReceive_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session), args[1].(io.Reader)) + }) + return _c +} + +func (_c *MockHandler_OnReceive_Call) Return() *MockHandler_OnReceive_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, io.Reader)) *MockHandler_OnReceive_Call { + _c.Run(run) + return _c +} + +// NewMockHandler creates a new instance of MockHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHandler { + mock := &MockHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/mocks/header.go b/pkg/networking/mocks/header.go new file mode 100644 index 000000000..eabade26a --- /dev/null +++ b/pkg/networking/mocks/header.go @@ -0,0 +1,238 @@ +// Code generated by mockery v2.50.1. DO NOT EDIT. + +package networking + +import ( + io "io" + + mock "github.com/stretchr/testify/mock" +) + +// MockHeader is an autogenerated mock type for the Header type +type MockHeader struct { + mock.Mock +} + +type MockHeader_Expecter struct { + mock *mock.Mock +} + +func (_m *MockHeader) EXPECT() *MockHeader_Expecter { + return &MockHeader_Expecter{mock: &_m.Mock} +} + +// HeaderLength provides a mock function with no fields +func (_m *MockHeader) HeaderLength() uint32 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HeaderLength") + } + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// MockHeader_HeaderLength_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HeaderLength' +type MockHeader_HeaderLength_Call struct { + *mock.Call +} + +// HeaderLength is a helper method to define mock.On call +func (_e *MockHeader_Expecter) HeaderLength() *MockHeader_HeaderLength_Call { + return &MockHeader_HeaderLength_Call{Call: _e.mock.On("HeaderLength")} +} + +func (_c *MockHeader_HeaderLength_Call) Run(run func()) *MockHeader_HeaderLength_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockHeader_HeaderLength_Call) Return(_a0 uint32) *MockHeader_HeaderLength_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockHeader_HeaderLength_Call) RunAndReturn(run func() uint32) *MockHeader_HeaderLength_Call { + _c.Call.Return(run) + return _c +} + +// PayloadLength provides a mock function with no fields +func (_m *MockHeader) PayloadLength() uint32 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for PayloadLength") + } + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// MockHeader_PayloadLength_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PayloadLength' +type MockHeader_PayloadLength_Call struct { + *mock.Call +} + +// PayloadLength is a helper method to define mock.On call +func (_e *MockHeader_Expecter) PayloadLength() *MockHeader_PayloadLength_Call { + return &MockHeader_PayloadLength_Call{Call: _e.mock.On("PayloadLength")} +} + +func (_c *MockHeader_PayloadLength_Call) Run(run func()) *MockHeader_PayloadLength_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockHeader_PayloadLength_Call) Return(_a0 uint32) *MockHeader_PayloadLength_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockHeader_PayloadLength_Call) RunAndReturn(run func() uint32) *MockHeader_PayloadLength_Call { + _c.Call.Return(run) + return _c +} + +// ReadFrom provides a mock function with given fields: r +func (_m *MockHeader) ReadFrom(r io.Reader) (int64, error) { + ret := _m.Called(r) + + if len(ret) == 0 { + panic("no return value specified for ReadFrom") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(io.Reader) (int64, error)); ok { + return rf(r) + } + if rf, ok := ret.Get(0).(func(io.Reader) int64); ok { + r0 = rf(r) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(io.Reader) error); ok { + r1 = rf(r) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeader_ReadFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadFrom' +type MockHeader_ReadFrom_Call struct { + *mock.Call +} + +// ReadFrom is a helper method to define mock.On call +// - r io.Reader +func (_e *MockHeader_Expecter) ReadFrom(r interface{}) *MockHeader_ReadFrom_Call { + return &MockHeader_ReadFrom_Call{Call: _e.mock.On("ReadFrom", r)} +} + +func (_c *MockHeader_ReadFrom_Call) Run(run func(r io.Reader)) *MockHeader_ReadFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(io.Reader)) + }) + return _c +} + +func (_c *MockHeader_ReadFrom_Call) Return(n int64, err error) *MockHeader_ReadFrom_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockHeader_ReadFrom_Call) RunAndReturn(run func(io.Reader) (int64, error)) *MockHeader_ReadFrom_Call { + _c.Call.Return(run) + return _c +} + +// WriteTo provides a mock function with given fields: w +func (_m *MockHeader) WriteTo(w io.Writer) (int64, error) { + ret := _m.Called(w) + + if len(ret) == 0 { + panic("no return value specified for WriteTo") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(io.Writer) (int64, error)); ok { + return rf(w) + } + if rf, ok := ret.Get(0).(func(io.Writer) int64); ok { + r0 = rf(w) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(io.Writer) error); ok { + r1 = rf(w) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeader_WriteTo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WriteTo' +type MockHeader_WriteTo_Call struct { + *mock.Call +} + +// WriteTo is a helper method to define mock.On call +// - w io.Writer +func (_e *MockHeader_Expecter) WriteTo(w interface{}) *MockHeader_WriteTo_Call { + return &MockHeader_WriteTo_Call{Call: _e.mock.On("WriteTo", w)} +} + +func (_c *MockHeader_WriteTo_Call) Run(run func(w io.Writer)) *MockHeader_WriteTo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(io.Writer)) + }) + return _c +} + +func (_c *MockHeader_WriteTo_Call) Return(n int64, err error) *MockHeader_WriteTo_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockHeader_WriteTo_Call) RunAndReturn(run func(io.Writer) (int64, error)) *MockHeader_WriteTo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockHeader creates a new instance of MockHeader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockHeader(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHeader { + mock := &MockHeader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/mocks/protocol.go b/pkg/networking/mocks/protocol.go new file mode 100644 index 000000000..dc30f5d74 --- /dev/null +++ b/pkg/networking/mocks/protocol.go @@ -0,0 +1,278 @@ +// Code generated by mockery v2.50.1. DO NOT EDIT. + +package networking + +import ( + mock "github.com/stretchr/testify/mock" + networking "github.com/wavesplatform/gowaves/pkg/networking" +) + +// MockProtocol is an autogenerated mock type for the Protocol type +type MockProtocol struct { + mock.Mock +} + +type MockProtocol_Expecter struct { + mock *mock.Mock +} + +func (_m *MockProtocol) EXPECT() *MockProtocol_Expecter { + return &MockProtocol_Expecter{mock: &_m.Mock} +} + +// EmptyHandshake provides a mock function with no fields +func (_m *MockProtocol) EmptyHandshake() networking.Handshake { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EmptyHandshake") + } + + var r0 networking.Handshake + if rf, ok := ret.Get(0).(func() networking.Handshake); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(networking.Handshake) + } + } + + return r0 +} + +// MockProtocol_EmptyHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmptyHandshake' +type MockProtocol_EmptyHandshake_Call struct { + *mock.Call +} + +// EmptyHandshake is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) EmptyHandshake() *MockProtocol_EmptyHandshake_Call { + return &MockProtocol_EmptyHandshake_Call{Call: _e.mock.On("EmptyHandshake")} +} + +func (_c *MockProtocol_EmptyHandshake_Call) Run(run func()) *MockProtocol_EmptyHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_EmptyHandshake_Call) Return(_a0 networking.Handshake) *MockProtocol_EmptyHandshake_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_EmptyHandshake_Call) RunAndReturn(run func() networking.Handshake) *MockProtocol_EmptyHandshake_Call { + _c.Call.Return(run) + return _c +} + +// EmptyHeader provides a mock function with no fields +func (_m *MockProtocol) EmptyHeader() networking.Header { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EmptyHeader") + } + + var r0 networking.Header + if rf, ok := ret.Get(0).(func() networking.Header); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(networking.Header) + } + } + + return r0 +} + +// MockProtocol_EmptyHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmptyHeader' +type MockProtocol_EmptyHeader_Call struct { + *mock.Call +} + +// EmptyHeader is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) EmptyHeader() *MockProtocol_EmptyHeader_Call { + return &MockProtocol_EmptyHeader_Call{Call: _e.mock.On("EmptyHeader")} +} + +func (_c *MockProtocol_EmptyHeader_Call) Run(run func()) *MockProtocol_EmptyHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_EmptyHeader_Call) Return(_a0 networking.Header) *MockProtocol_EmptyHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_EmptyHeader_Call) RunAndReturn(run func() networking.Header) *MockProtocol_EmptyHeader_Call { + _c.Call.Return(run) + return _c +} + +// IsAcceptableHandshake provides a mock function with given fields: _a0 +func (_m *MockProtocol) IsAcceptableHandshake(_a0 networking.Handshake) bool { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for IsAcceptableHandshake") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(networking.Handshake) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockProtocol_IsAcceptableHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsAcceptableHandshake' +type MockProtocol_IsAcceptableHandshake_Call struct { + *mock.Call +} + +// IsAcceptableHandshake is a helper method to define mock.On call +// - _a0 networking.Handshake +func (_e *MockProtocol_Expecter) IsAcceptableHandshake(_a0 interface{}) *MockProtocol_IsAcceptableHandshake_Call { + return &MockProtocol_IsAcceptableHandshake_Call{Call: _e.mock.On("IsAcceptableHandshake", _a0)} +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) Run(run func(_a0 networking.Handshake)) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(networking.Handshake)) + }) + return _c +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) Return(_a0 bool) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) RunAndReturn(run func(networking.Handshake) bool) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Return(run) + return _c +} + +// IsAcceptableMessage provides a mock function with given fields: _a0 +func (_m *MockProtocol) IsAcceptableMessage(_a0 networking.Header) bool { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for IsAcceptableMessage") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(networking.Header) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockProtocol_IsAcceptableMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsAcceptableMessage' +type MockProtocol_IsAcceptableMessage_Call struct { + *mock.Call +} + +// IsAcceptableMessage is a helper method to define mock.On call +// - _a0 networking.Header +func (_e *MockProtocol_Expecter) IsAcceptableMessage(_a0 interface{}) *MockProtocol_IsAcceptableMessage_Call { + return &MockProtocol_IsAcceptableMessage_Call{Call: _e.mock.On("IsAcceptableMessage", _a0)} +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) Run(run func(_a0 networking.Header)) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(networking.Header)) + }) + return _c +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) Return(_a0 bool) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) RunAndReturn(run func(networking.Header) bool) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Return(run) + return _c +} + +// Ping provides a mock function with no fields +func (_m *MockProtocol) Ping() ([]byte, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProtocol_Ping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ping' +type MockProtocol_Ping_Call struct { + *mock.Call +} + +// Ping is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) Ping() *MockProtocol_Ping_Call { + return &MockProtocol_Ping_Call{Call: _e.mock.On("Ping")} +} + +func (_c *MockProtocol_Ping_Call) Run(run func()) *MockProtocol_Ping_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_Ping_Call) Return(_a0 []byte, _a1 error) *MockProtocol_Ping_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProtocol_Ping_Call) RunAndReturn(run func() ([]byte, error)) *MockProtocol_Ping_Call { + _c.Call.Return(run) + return _c +} + +// NewMockProtocol creates a new instance of MockProtocol. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockProtocol(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProtocol { + mock := &MockProtocol{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/network.go b/pkg/networking/network.go new file mode 100644 index 000000000..f145b4cb1 --- /dev/null +++ b/pkg/networking/network.go @@ -0,0 +1,59 @@ +package networking + +import ( + "context" + "errors" + "io" +) + +const Namespace = "NET" + +// TODO: Consider special Error type for all [networking] errors. +var ( + // ErrInvalidConfigurationNoProtocol is used when the configuration has no protocol. + ErrInvalidConfigurationNoProtocol = errors.New("invalid configuration: empty protocol") + + // ErrInvalidConfigurationNoHandler is used when the configuration has no handler. + ErrInvalidConfigurationNoHandler = errors.New("invalid configuration: empty handler") + + // ErrInvalidConfigurationNoKeepAliveInterval is used when the configuration has an invalid keep-alive interval. + ErrInvalidConfigurationNoKeepAliveInterval = errors.New("invalid configuration: invalid keep-alive interval value") + + // ErrInvalidConfigurationNoWriteTimeout is used when the configuration has an invalid write timeout. + ErrInvalidConfigurationNoWriteTimeout = errors.New("invalid configuration: invalid write timeout value") + + // ErrUnacceptableHandshake is used when the handshake is not accepted. + ErrUnacceptableHandshake = errors.New("handshake is not accepted") + + // ErrSessionShutdown is used if there is a shutdown during an operation. + ErrSessionShutdown = errors.New("session shutdown") + + // ErrConnectionWriteTimeout indicates that we hit the timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = errors.New("connection write timeout") + + // ErrKeepAliveProtocolFailure is used when the protocol failed to provide a keep-alive message. + ErrKeepAliveProtocolFailure = errors.New("protocol failed to provide a keep-alive message") + + // ErrConnectionClosedOnRead indicates that the connection was closed while reading. + ErrConnectionClosedOnRead = errors.New("connection closed on read") + + // ErrKeepAliveTimeout indicates that we failed to send keep-alive message and abandon a keep-alive loop. + ErrKeepAliveTimeout = errors.New("keep-alive loop timeout") + + // ErrEmptyTimerPool is raised on creation of Session with a nil pool. + ErrEmptyTimerPool = errors.New("empty timer pool") +) + +type Network struct { + tp *timerPool +} + +func NewNetwork() *Network { + return &Network{ + tp: newTimerPool(), + } +} + +func (n *Network) NewSession(ctx context.Context, conn io.ReadWriteCloser, conf *Config) (*Session, error) { + return newSession(ctx, conf, conn, n.tp) +} diff --git a/pkg/networking/protocol.go b/pkg/networking/protocol.go new file mode 100644 index 000000000..a5b97ec25 --- /dev/null +++ b/pkg/networking/protocol.go @@ -0,0 +1,38 @@ +package networking + +import "io" + +// Header is the interface that should be implemented by the real message header packet. +type Header interface { + io.ReaderFrom + io.WriterTo + HeaderLength() uint32 + PayloadLength() uint32 +} + +// Handshake is the common interface for a handshake packet. +type Handshake interface { + io.ReaderFrom + io.WriterTo +} + +// Protocol is the interface for the network protocol implementation. +// It provides the methods to create the handshake packet, message header, and ping packet. +// It also provides the methods to validate the handshake and message header packets. +type Protocol interface { + // EmptyHandshake returns the empty instance of the handshake packet. + EmptyHandshake() Handshake + + // EmptyHeader returns the empty instance of the message header. + EmptyHeader() Header + + // Ping return the actual ping packet. + Ping() ([]byte, error) + + // IsAcceptableHandshake checks the handshake is acceptable. + IsAcceptableHandshake(Handshake) bool + + // IsAcceptableMessage checks the message is acceptable by examining its header. + // If return false, the message will be discarded. + IsAcceptableMessage(Header) bool +} diff --git a/pkg/networking/session.go b/pkg/networking/session.go new file mode 100644 index 000000000..bdab4bdd4 --- /dev/null +++ b/pkg/networking/session.go @@ -0,0 +1,407 @@ +package networking + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "log/slog" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/wavesplatform/gowaves/pkg/execution" +) + +// Session is used to wrap a reliable ordered connection. +type Session struct { + g *execution.TaskGroup + ctx context.Context + cancel context.CancelFunc + + config *Config + logger *slog.Logger + tp *timerPool + + conn io.ReadWriteCloser // conn is the underlying connection + bufRead *bufio.Reader // buffered reader wrapped around the connection + + receiveLock sync.Mutex // Guards the receiveBuffer. + receiveBuffer *bytes.Buffer // receiveBuffer is used to store the incoming data. + + sendLock sync.Mutex // Guards the sendCh. + sendCh chan *sendPacket // sendCh is used to send data to the connection. + + receiving atomic.Bool // Indicates that receiveLoop already running. + established atomic.Bool // Indicates that incoming Handshake was successfully accepted. + shutdown sync.Once // shutdown is used to safely close the Session. +} + +// NewSession is used to construct a new session. +func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp *timerPool) (*Session, error) { + if config.protocol == nil { + return nil, ErrInvalidConfigurationNoProtocol + } + if config.handler == nil { + return nil, ErrInvalidConfigurationNoHandler + } + if config.keepAlive && config.keepAliveInterval <= 0 { + return nil, ErrInvalidConfigurationNoKeepAliveInterval + } + if config.connectionWriteTimeout <= 0 { + return nil, ErrInvalidConfigurationNoWriteTimeout + } + if tp == nil { + return nil, ErrEmptyTimerPool + } + + sCtx, cancel := context.WithCancel(ctx) + s := &Session{ + g: execution.NewTaskGroup(suppressContextCancellationError), + ctx: sCtx, + cancel: cancel, + config: config, + tp: tp, + conn: conn, + bufRead: bufio.NewReader(conn), + sendCh: make(chan *sendPacket, 1), // TODO: Make the size of send channel configurable. + } + + slogHandler := config.slogHandler + if slogHandler == nil { + slogHandler = discardingHandler{} + } + + sa := [...]any{ + slog.String("namespace", Namespace), + slog.String("remote", s.RemoteAddr().String()), + } + attrs := append(sa[:], config.attributes...) + s.logger = slog.New(slogHandler).With(attrs...) + + s.g.Run(s.receiveLoop) + s.g.Run(s.sendLoop) + if s.config.keepAlive { + s.g.Run(s.keepaliveLoop) + } + + return s, nil +} + +func (s *Session) String() string { + return fmt.Sprintf("Session{local=%s,remote=%s}", s.LocalAddr(), s.RemoteAddr()) +} + +// LocalAddr returns the local network address. +func (s *Session) LocalAddr() net.Addr { + if a, ok := s.conn.(addressable); ok { + return a.LocalAddr() + } + return &sessionAddress{addr: "local"} +} + +// RemoteAddr returns the remote network address. +func (s *Session) RemoteAddr() net.Addr { + if a, ok := s.conn.(addressable); ok { + return a.RemoteAddr() + } + return &sessionAddress{addr: "remote"} +} + +// Close is used to close the session. It is safe to call Close multiple times from different goroutines, +// subsequent calls do nothing. +func (s *Session) Close() error { + var err error + s.shutdown.Do(func() { + s.logger.Debug("Closing session") + clErr := s.conn.Close() // Close the underlying connection. + if clErr != nil { + s.logger.Warn("Failed to close underlying connection", "error", clErr) + } + s.logger.Debug("Underlying connection closed") + + s.cancel() // Cancel the underlying context to interrupt the loops. + + s.logger.Debug("Waiting for loops to finish") + err = s.g.Wait() // Wait for loops to finish. + + err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. + + s.logger.Debug("Session closed", "error", err) + }) + return err +} + +// Write is used to write to the session. It is safe to call Write and/or Close concurrently. +func (s *Session) Write(msg []byte) (int, error) { + s.sendLock.Lock() + defer s.sendLock.Unlock() + + if err := s.waitForSend(msg); err != nil { + return 0, err + } + + return len(msg), nil +} + +// waitForSend waits to send a data, checking for a potential context cancellation. +func (s *Session) waitForSend(data []byte) error { + // Channel to receive an error from sendLoop goroutine. + // We are not closing this channel, it will be GCed when the session is closed. + errCh := make(chan error, 1) + + timer := s.tp.Get() + timer.Reset(s.config.connectionWriteTimeout) + defer s.tp.Put(timer) + + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) + } + select { + case s.sendCh <- newSendPacket(data, errCh): + s.logger.Debug("Data written into send channel") + case <-s.ctx.Done(): + s.logger.Debug("Session shutdown while sending data") + return ErrSessionShutdown + case <-timer.C: + s.logger.Debug("Connection write timeout while sending data") + return ErrConnectionWriteTimeout + } + + select { + case err, ok := <-errCh: + if !ok { + s.logger.Debug("Data sent successfully") + return nil // No error, data was sent successfully. + } + s.logger.Debug("Error sending data", "error", err) + return err + case <-s.ctx.Done(): + s.logger.Debug("Session shutdown while waiting send error") + return ErrSessionShutdown + case <-timer.C: + s.logger.Debug("Connection write timeout while waiting send error") + return ErrConnectionWriteTimeout + } +} + +// sendLoop is a long-running goroutine that sends data to the connection. +func (s *Session) sendLoop() error { + var dataBuf bytes.Buffer + for { + dataBuf.Reset() + + select { + case <-s.ctx.Done(): + s.logger.Debug("Exiting connection send loop") + return s.ctx.Err() + + case packet := <-s.sendCh: + packet.mu.Lock() + _, rErr := dataBuf.ReadFrom(packet.r) + if rErr != nil { + packet.mu.Unlock() + s.logger.Error("Failed to copy data into buffer", "error", rErr) + s.asyncSendErr(packet.err, rErr) + return rErr + } + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Sending data to connection", + "data", base64.StdEncoding.EncodeToString(dataBuf.Bytes())) + } + packet.mu.Unlock() + + if dataBuf.Len() > 0 { + s.logger.Debug("Writing data into connection", "len", len(dataBuf.Bytes())) + _, err := s.conn.Write(dataBuf.Bytes()) // TODO: We are locking here, because no timeout set on connection itself. + if err != nil { + s.logger.Error("Failed to write data into connection", "error", err) + s.asyncSendErr(packet.err, err) + return err + } + s.logger.Debug("Data written into connection") + } + + // No error, close the channel. + close(packet.err) + } + } +} + +// receiveLoop continues to receive data until a fatal error is encountered or underlying connection is closed. +// Receive loop works after handshake and accepts only length-prepended messages. +func (s *Session) receiveLoop() error { + if !s.receiving.CompareAndSwap(false, true) { + return nil // Prevent running multiple receive loops. + } + for { + if err := s.receive(); err != nil { + if errors.Is(err, ErrConnectionClosedOnRead) { + s.config.handler.OnClose(s) + return nil // Exit normally on connection close. + } + return err + } + } +} + +func (s *Session) receive() error { + if s.established.Load() { + hdr := s.config.protocol.EmptyHeader() + return s.readMessage(hdr) + } + return s.readHandshake() +} + +func (s *Session) readHandshake() error { + s.logger.Debug("Reading handshake") + + hs := s.config.protocol.EmptyHandshake() + _, err := hs.ReadFrom(s.bufRead) + if err != nil { + if errors.Is(err, io.EOF) { + return ErrConnectionClosedOnRead + } + if errMsg := err.Error(); strings.Contains(errMsg, "closed") || + strings.Contains(errMsg, "reset by peer") { + return errors.Join(ErrConnectionClosedOnRead, err) // Wrap the error with ErrConnectionClosedOnRead. + } + s.logger.Error("Failed to read handshake from connection", "error", err) + return err + } + s.logger.Debug("Handshake successfully read") + + if !s.config.protocol.IsAcceptableHandshake(hs) { + s.logger.Error("Handshake is not acceptable") + return ErrUnacceptableHandshake + } + // Handshake is acceptable, we can switch the session into established state. + s.established.Store(true) + s.config.handler.OnHandshake(s, hs) + return nil +} + +func (s *Session) readMessage(hdr Header) error { + // Read the header + if _, err := hdr.ReadFrom(s.bufRead); err != nil { + if errors.Is(err, io.EOF) { + return ErrConnectionClosedOnRead + } + if errMsg := err.Error(); strings.Contains(errMsg, "closed") || + strings.Contains(errMsg, "reset by peer") || + strings.Contains(errMsg, "broken pipe") { // In Docker network built on top of pipe, we get this error on close. + return errors.Join(ErrConnectionClosedOnRead, err) // Wrap the error with ErrConnectionClosedOnRead. + } + s.logger.Error("Failed to read header", "error", err) + return err + } + if !s.config.protocol.IsAcceptableMessage(hdr) { + // We have to discard the remaining part of the message. + if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.PayloadLength())); err != nil { + s.logger.Error("Failed to discard message", "error", err) + return err + } + } + // Read the new data + if err := s.readMessagePayload(hdr, s.bufRead); err != nil { + s.logger.Error("Failed to read message", "error", err) + return err + } + return nil +} + +func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { + // Wrap in a limited reader + s.logger.Debug("Reading message payload", "len", hdr.PayloadLength()) + conn = io.LimitReader(conn, int64(hdr.PayloadLength())) + + // Copy into buffer + s.receiveLock.Lock() + defer s.receiveLock.Unlock() + + if s.receiveBuffer == nil { + // Allocate the receiving buffer just-in-time to fit the full message. + s.receiveBuffer = bytes.NewBuffer(make([]byte, 0, hdr.HeaderLength()+hdr.PayloadLength())) + } + defer s.receiveBuffer.Reset() + _, err := hdr.WriteTo(s.receiveBuffer) + if err != nil { + s.logger.Error("Failed to write header to receiving buffer", "error", err) + return err + } + n, err := io.Copy(s.receiveBuffer, conn) + if err != nil { + s.logger.Error("Failed to copy payload to receiving buffer", "error", err) + return err + } + s.logger.Debug("Message payload successfully read", "len", n) + + // We lock the buffer from modification on the time of invocation of OnReceive handler. + // The slice of bytes passed into the handler is only valid for the duration of the handler invocation. + // So inside the handler better deserialize message or make a copy of the bytes. + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Invoking OnReceive handler", "message", + base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) + } + s.config.handler.OnReceive(s, s.receiveBuffer) // Invoke OnReceive handler. + return nil +} + +// keepaliveLoop is a long-running goroutine that periodically sends a Ping message to keep the connection alive. +func (s *Session) keepaliveLoop() error { + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + case <-time.After(s.config.keepAliveInterval): + // Get actual Ping message from Protocol. + p, err := s.config.protocol.Ping() + if err != nil { + s.logger.Error("Failed to get ping message", "error", err) + return ErrKeepAliveProtocolFailure + } + if sndErr := s.waitForSend(p); sndErr != nil { + if errors.Is(sndErr, ErrSessionShutdown) { + return nil // Exit normally on session termination. + } + s.logger.Error("Failed to send ping message", "error", err) + return ErrKeepAliveTimeout + } + } + } +} + +// sendPacket is used to send data. +type sendPacket struct { + mu sync.Mutex // Protects data from unsafe reads. + r io.Reader + err chan<- error +} + +func newSendPacket(data []byte, ch chan<- error) *sendPacket { + return &sendPacket{r: bytes.NewReader(data), err: ch} +} + +// asyncSendErr is used to try an async send of an error. +func (s *Session) asyncSendErr(ch chan<- error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + s.logger.Debug("Error sent to channel", "error", err) + default: + } +} + +func suppressContextCancellationError(err error) error { + if errors.Is(err, context.Canceled) { + return nil + } + return err +} diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go new file mode 100644 index 000000000..c30bfac81 --- /dev/null +++ b/pkg/networking/session_test.go @@ -0,0 +1,507 @@ +package networking_test + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "log/slog" + "sync" + "testing" + "time" + + "github.com/neilotoole/slogt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wavesplatform/gowaves/pkg/networking" + netmocks "github.com/wavesplatform/gowaves/pkg/networking/mocks" +) + +func TestSuccessfulSession(t *testing.T) { + defer goleak.VerifyNone(t) + + p := netmocks.NewMockProtocol(t) + p.On("EmptyHandshake").Return(&textHandshake{}, nil) + p.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + p.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + p.On("EmptyHeader").Return(&textHeader{}, nil) + p.On("IsAcceptableMessage", &textHeader{l: 2}).Once().Return(true) + p.On("IsAcceptableMessage", &textHeader{l: 13}).Once().Return(true) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + cs, err := net.NewSession(ctx, clientConn, testConfig(t, p, clientHandler, "client")) + require.NoError(t, err) + ss, err := net.NewSession(ctx, serverConn, testConfig(t, p, serverHandler, "server")) + require.NoError(t, err) + + var sWG sync.WaitGroup + var cWG sync.WaitGroup + sWG.Add(1) + go func() { + sc1 := serverHandler.On("OnHandshake", ss, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + n, wErr := ss.Write([]byte("hello")) + require.NoError(t, wErr) + assert.Equal(t, 5, n) + }) + sc2 := serverHandler.On("OnReceive", ss, bytes.NewBuffer(encodeMessage("Hello session"))). + Once().Return() + sc2.NotBefore(sc1). + Run(func(_ mock.Arguments) { + n, wErr := ss.Write(encodeMessage("Hi")) + require.NoError(t, wErr) + assert.Equal(t, 6, n) + sWG.Done() + }) + sWG.Wait() + }() + + cWG.Add(1) + cl1 := clientHandler.On("OnHandshake", cs, &textHandshake{v: "hello"}).Once().Return() + cl1.Run(func(_ mock.Arguments) { + n, wErr := cs.Write(encodeMessage("Hello session")) + require.NoError(t, wErr) + assert.Equal(t, 17, n) + }) + cl2 := clientHandler.On("OnReceive", cs, bytes.NewBuffer(encodeMessage("Hi"))).Once().Return() + cl2.NotBefore(cl1). + Run(func(_ mock.Arguments) { + cWG.Done() + }) + + n, err := cs.Write([]byte("hello")) // Send handshake to server. + require.NoError(t, err) + assert.Equal(t, 5, n) + + cWG.Wait() // Wait for server to finish. + + clientHandler.On("OnClose", cs).Return() + serverHandler.On("OnClose", ss).Return() + err = cs.Close() + assert.NoError(t, err) + err = ss.Close() + assert.NoError(t, err) +} + +func TestSessionTimeoutOnHandshake(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + clientHandler.On("OnClose", clientSession).Return() + + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + serverHandler.On("OnClose", serverSession).Return() + + // Lock + pc, ok := clientConn.(*pipeConn) + require.True(t, ok) + pc.writeBlocker.Lock() + + // Send handshake to server, but writing will block because the clientConn is locked. + n, err := clientSession.Write([]byte("hello")) + require.ErrorIs(t, err, networking.ErrConnectionWriteTimeout) + assert.Equal(t, 0, n) + + err = serverSession.Close() + assert.NoError(t, err) + + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + err = clientSession.Close() + assert.ErrorIs(t, err, io.ErrClosedPipe) + wg.Done() + }() + + // Unlock "timeout" and close client. + pc.writeBlocker.Unlock() + wg.Wait() +} + +func TestSessionTimeoutOnMessage(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + pc, ok := clientConn.(*pipeConn) + require.True(t, ok) + + serverHandler.On("OnClose", serverSession).Return() + + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. + + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to reply with Handshake to client. + + pipeWG := new(sync.WaitGroup) + pipeWG.Add(1) // Wait for pipe to be locked. + + testWG := new(sync.WaitGroup) + testWG.Add(1) // Wait for client fail by timeout. + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. + n, wErr := serverSession.Write([]byte("hello")) + require.NoError(t, wErr) + assert.Equal(t, 5, n) + serverWG.Done() + }) + + clientHandler.On("OnClose", clientSession).Return() + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + pipeWG.Wait() // Wait for pipe to be locked. + // On receiving handshake from server, send the message back to server. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.ErrorIs(t, msgErr, networking.ErrConnectionWriteTimeout) + testWG.Done() + }) + + go func() { + serverWG.Wait() // Wait for finishing handshake before closing the pipe. + pc.writeBlocker.Lock() // Lock pipe after replying with the handshake from server. + pipeWG.Done() // Signal that pipe is locked. + }() + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() + + err = serverSession.Close() + assert.NoError(t, err) // Expect no error on the server side. + + pc.writeBlocker.Unlock() // Unlock the pipe. + + err = clientSession.Close() + assert.ErrorIs(t, err, io.ErrClosedPipe) // Expect error because connection to the server already closed. +} + +func TestDoubleClose(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + clientHandler.On("OnClose", clientSession).Return() + serverHandler.On("OnClose", serverSession).Return() + + err = clientSession.Close() + assert.NoError(t, err) + err = clientSession.Close() + assert.NoError(t, err) + + err = serverSession.Close() + assert.NoError(t, err) + err = serverSession.Close() + assert.NoError(t, err) +} + +func TestOnClosedByOtherSide(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. + + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to send Handshake to client, after that close the connection from server. + closeWG := new(sync.WaitGroup) + closeWG.Add(1) // Wait for server to close the connection. + + testWG := new(sync.WaitGroup) + testWG.Add(2) // Wait for both client and server to finish. + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. + n, wErr := serverSession.Write([]byte("hello")) + assert.NoError(t, wErr) + assert.Equal(t, 5, n) + go func() { + // Close server after client received the handshake from server. + serverWG.Wait() // Wait for client to receive server handshake. + clErr := serverSession.Close() + assert.NoError(t, clErr) + closeWG.Done() + testWG.Done() + }() + }) + + clientHandler.On("OnClose", clientSession).Return() + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + // On receiving handshake from server, signal to close the server. + serverWG.Done() + // Try to send message to server, but it will fail because server is already closed. + closeWG.Wait() // Wait for server to close. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.ErrorIs(t, msgErr, io.ErrClosedPipe) + testWG.Done() + }) + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() // Wait for client to finish. + err = clientSession.Close() + assert.ErrorIs(t, err, io.ErrClosedPipe) // Close reports the same error, because it was registered in the send loop. +} + +func TestCloseParentContext(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. + + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to send Handshake to client, after that we will close the parent context. + + testWG := new(sync.WaitGroup) + testWG.Add(2) // Wait for both client and server to finish. + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. + n, wErr := serverSession.Write([]byte("hello")) + assert.NoError(t, wErr) + assert.Equal(t, 5, n) + go func() { + serverWG.Wait() // Wait for client to receive server handshake. + cancel() // Close parent context. + testWG.Done() + }() + }) + + clientHandler.On("OnClose", clientSession).Return() + + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + // On receiving handshake from server, signal to close the server. + serverWG.Done() + go func() { + // Try to send message to server, but it will fail because server is already closed. + time.Sleep(10 * time.Millisecond) // Wait for server to close. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.ErrorIs(t, msgErr, networking.ErrSessionShutdown) + testWG.Done() + }() + }) + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() // Wait for all interactions to finish. + + err = clientSession.Close() + assert.NoError(t, err) + err = serverSession.Close() + assert.NoError(t, err) +} + +func testConfig(t testing.TB, p networking.Protocol, h networking.Handler, direction string) *networking.Config { + log := slogt.New(t) + return networking.NewConfig(p, h). + WithSlogHandler(log.Handler()). + WithWriteTimeout(1 * time.Second). + WithKeepAliveDisabled(). + WithSlogAttribute(slog.String("direction", direction)) +} + +type pipeConn struct { + reader *io.PipeReader + writer *io.PipeWriter + writeBlocker sync.Mutex +} + +func (p *pipeConn) Read(b []byte) (int, error) { + return p.reader.Read(b) +} + +func (p *pipeConn) Write(b []byte) (int, error) { + p.writeBlocker.Lock() + defer p.writeBlocker.Unlock() + return p.writer.Write(b) +} + +func (p *pipeConn) Close() error { + rErr := p.reader.Close() + wErr := p.writer.Close() + return errors.Join(rErr, wErr) +} + +func testConnPipe() (io.ReadWriteCloser, io.ReadWriteCloser) { + read1, write1 := io.Pipe() + read2, write2 := io.Pipe() + conn1 := &pipeConn{reader: read1, writer: write2} + conn2 := &pipeConn{reader: read2, writer: write1} + return conn1, conn2 +} + +func encodeMessage(s string) []byte { + msg := make([]byte, 4+len(s)) + binary.BigEndian.PutUint32(msg[:4], uint32(len(s))) + copy(msg[4:], s) + return msg +} + +// We have to use the "real" handshake, not a mock, because we are reading or writing to a "real" piped connection. +type textHandshake struct { + v string +} + +func (h *textHandshake) ReadFrom(r io.Reader) (int64, error) { + buf := make([]byte, 5) + n, err := io.ReadFull(r, buf) + if err != nil { + return int64(n), err + } + h.v = string(buf[:n]) + return int64(n), nil +} + +func (h *textHandshake) WriteTo(w io.Writer) (int64, error) { + buf := []byte(h.v) + n, err := w.Write(buf) + return int64(n), err +} + +// We have to use the "real" header, not a mock, because we are reading or writing to a "real" piped connection. +type textHeader struct { + l uint32 +} + +func (h *textHeader) HeaderLength() uint32 { + return 4 +} + +func (h *textHeader) PayloadLength() uint32 { + return h.l +} + +func (h *textHeader) ReadFrom(r io.Reader) (int64, error) { + hdr := make([]byte, 4) + n, err := io.ReadFull(r, hdr) + if err != nil { + return int64(n), err + } + h.l = binary.BigEndian.Uint32(hdr) + return int64(n), nil +} + +func (h *textHeader) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, h.l) + n, err := w.Write(buf) + return int64(n), err +} diff --git a/pkg/networking/timers.go b/pkg/networking/timers.go new file mode 100644 index 000000000..9dd227c8a --- /dev/null +++ b/pkg/networking/timers.go @@ -0,0 +1,41 @@ +package networking + +import ( + "sync" + "time" +) + +type timerPool struct { + p *sync.Pool +} + +func newTimerPool() *timerPool { + const initialTimerInterval = time.Hour * 1e6 + return &timerPool{ + p: &sync.Pool{ + New: func() any { + timer := time.NewTimer(initialTimerInterval) + timer.Stop() + return timer + }, + }, + } +} + +func (p *timerPool) Get() *time.Timer { + t, ok := p.p.Get().(*time.Timer) + if !ok { + panic("invalid type of item in TimerPool") + } + return t +} + +func (p *timerPool) Put(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + p.p.Put(t) +} diff --git a/pkg/p2p/conn/conn.go b/pkg/p2p/conn/conn.go index b477bbe55..d2de99d95 100644 --- a/pkg/p2p/conn/conn.go +++ b/pkg/p2p/conn/conn.go @@ -124,12 +124,12 @@ func receiveFromRemote(conn deadlineReader, fromRemoteCh chan *bytebufferpool.By return errors.Wrap(err, "failed to read header") } // received too big message, probably it's an error - if l := int(header.HeaderLength() + header.PayloadLength); l > maxMessageSize { + if l := int(header.HeaderLength() + header.PayloadLength()); l > maxMessageSize { return errors.Errorf("received too long message, size=%d > max=%d", l, maxMessageSize) } if skip(header) { - if _, err := io.CopyN(io.Discard, reader, int64(header.PayloadLength)); err != nil { + if _, err := io.CopyN(io.Discard, reader, int64(header.PayloadLength())); err != nil { return errors.Wrap(err, "failed to skip payload") } continue @@ -142,7 +142,7 @@ func receiveFromRemote(conn deadlineReader, fromRemoteCh chan *bytebufferpool.By return errors.Wrap(err, "failed to write header into buff") } // then read all message to remaining buffer - if _, err := io.CopyN(b, reader, int64(header.PayloadLength)); err != nil { + if _, err := io.CopyN(b, reader, int64(header.PayloadLength())); err != nil { bytebufferpool.Put(b) return errors.Wrap(err, "failed to read payload into buffer") } diff --git a/pkg/proto/microblock.go b/pkg/proto/microblock.go index a02f6657c..ceb9c6048 100644 --- a/pkg/proto/microblock.go +++ b/pkg/proto/microblock.go @@ -11,6 +11,7 @@ import ( g "github.com/wavesplatform/gowaves/pkg/grpc/generated/waves" "github.com/wavesplatform/gowaves/pkg/libs/deserializer" "github.com/wavesplatform/gowaves/pkg/libs/serializer" + "github.com/wavesplatform/gowaves/pkg/util/common" ) const ( @@ -278,7 +279,7 @@ func (a *MicroBlockMessage) UnmarshalBinary(data []byte) error { if len(data) < crypto.SignatureSize*2+1 { return errors.New("invalid micro block size") } - b := make([]byte, len(data[:h.PayloadLength])) + b := make([]byte, len(data[:h.payloadLength])) copy(b, data) a.Body = b @@ -311,7 +312,7 @@ func (a *MicroBlockInvMessage) WriteTo(w io.Writer) (n int64, err error) { h.Length = maxHeaderLength + uint32(len(a.Body)) - 4 h.Magic = headerMagic h.ContentID = ContentIDInvMicroblock - h.PayloadLength = uint32(len(a.Body)) + h.payloadLength = common.SafeIntToUint32(len(a.Body)) dig, err := crypto.FastHash(a.Body) if err != nil { return 0, err @@ -351,10 +352,10 @@ func (a *MicroBlockRequestMessage) ReadFrom(_ io.Reader) (n int64, err error) { func (a *MicroBlockRequestMessage) WriteTo(w io.Writer) (int64, error) { var h Header - h.Length = maxHeaderLength + uint32(len(a.TotalBlockSig)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(a.TotalBlockSig)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDMicroblockRequest - h.PayloadLength = uint32(len(a.TotalBlockSig)) + h.payloadLength = common.SafeIntToUint32(len(a.TotalBlockSig)) dig, err := crypto.FastHash(a.TotalBlockSig) if err != nil { return 0, err @@ -393,7 +394,7 @@ func (a *MicroBlockRequestMessage) UnmarshalBinary(data []byte) error { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } data = data[17:] - body := make([]byte, h.PayloadLength) + body := make([]byte, h.payloadLength) copy(body, data) a.TotalBlockSig = body return nil @@ -517,8 +518,8 @@ func (a *MicroBlockInvMessage) UnmarshalBinary(data []byte) error { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } data = data[17:] - body := make([]byte, h.PayloadLength) - copy(body, data[:h.PayloadLength]) + body := make([]byte, h.payloadLength) + copy(body, data[:h.payloadLength]) a.Body = body return nil } @@ -563,15 +564,15 @@ func (a *PBMicroBlockMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDPBMicroBlock { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if h.PayloadLength < crypto.DigestSize { + if h.payloadLength < crypto.DigestSize { return errors.New("PBMicroBlockMessage UnmarshalBinary: invalid data size") } data = data[17:] - if uint32(len(data)) < h.PayloadLength { + if common.SafeIntToUint32(len(data)) < h.payloadLength { return errors.New("invalid data size") } - mbBytes := data[:h.PayloadLength] + mbBytes := data[:h.payloadLength] a.MicroBlockBytes = make([]byte, len(mbBytes)) copy(a.MicroBlockBytes, mbBytes) return nil diff --git a/pkg/proto/proto.go b/pkg/proto/proto.go index 8f039d9b0..7c2fc2d1a 100644 --- a/pkg/proto/proto.go +++ b/pkg/proto/proto.go @@ -16,6 +16,7 @@ import ( "github.com/wavesplatform/gowaves/pkg/crypto" "github.com/wavesplatform/gowaves/pkg/util/collect_writes" + "github.com/wavesplatform/gowaves/pkg/util/common" ) const ( @@ -74,7 +75,7 @@ type Header struct { Length uint32 Magic uint32 ContentID PeerMessageID - PayloadLength uint32 + payloadLength uint32 PayloadChecksum [headerChecksumLen]byte } @@ -97,7 +98,7 @@ func (h *Header) WriteTo(w io.Writer) (int64, error) { } func (h *Header) HeaderLength() uint32 { - if h.PayloadLength > 0 { + if h.payloadLength > 0 { return headerSizeWithPayload } return headerSizeWithoutPayload @@ -133,8 +134,8 @@ func (h *Header) UnmarshalBinary(data []byte) error { return fmt.Errorf("received wrong magic: want %x, have %x", headerMagic, h.Magic) } h.ContentID = PeerMessageID(data[HeaderContentIDPosition]) - h.PayloadLength = binary.BigEndian.Uint32(data[9:headerSizeWithoutPayload]) - if h.PayloadLength > 0 { + h.payloadLength = binary.BigEndian.Uint32(data[9:headerSizeWithoutPayload]) + if h.payloadLength > 0 { if uint32(len(data)) < headerSizeWithPayload { return errors.New("Header UnmarshalBinary: invalid data size") } @@ -151,8 +152,8 @@ func (h *Header) Copy(data []byte) (int, error) { binary.BigEndian.PutUint32(data[0:4], h.Length) binary.BigEndian.PutUint32(data[4:8], headerMagic) data[HeaderContentIDPosition] = byte(h.ContentID) - binary.BigEndian.PutUint32(data[9:headerSizeWithoutPayload], h.PayloadLength) - if h.PayloadLength > 0 { + binary.BigEndian.PutUint32(data[9:headerSizeWithoutPayload], h.payloadLength) + if h.payloadLength > 0 { if len(data) < headerSizeWithPayload { return 0, errors.New("Header Copy: invalid data size") } @@ -162,6 +163,10 @@ func (h *Header) Copy(data []byte) (int, error) { return headerSizeWithoutPayload, nil } +func (h *Header) PayloadLength() uint32 { + return h.payloadLength +} + // Version represents the version of the protocol type Version struct { _ struct{} // this field disallows raw struct initialization @@ -490,10 +495,6 @@ func (a HandshakeTCPAddr) Network() string { return "tcp" } -func ParseHandshakeTCPAddr(s string) HandshakeTCPAddr { - return HandshakeTCPAddr(NewTCPAddrFromString(s)) -} - type U8String struct { S string } @@ -647,7 +648,7 @@ func (m *GetPeersMessage) MarshalBinary() ([]byte, error) { h.Length = maxHeaderLength - 8 h.Magic = headerMagic h.ContentID = ContentIDGetPeers - h.PayloadLength = 0 + h.payloadLength = 0 return h.MarshalBinary() } @@ -663,7 +664,7 @@ func (m *GetPeersMessage) UnmarshalBinary(b []byte) error { if header.ContentID != ContentIDGetPeers { return fmt.Errorf("getpeers message ContentID is unexpected: want %x have %x", ContentIDGetPeers, header.ContentID) } - if header.PayloadLength != 0 { + if header.payloadLength != 0 { return fmt.Errorf("getpeers message length is not zero") } @@ -969,10 +970,10 @@ func (m *PeersMessage) WriteTo(w io.Writer) (int64, error) { return n, err } - h.Length = maxHeaderLength + uint32(len(buf.Bytes())) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(buf.Bytes())) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPeers - h.PayloadLength = uint32(len(buf.Bytes())) + h.payloadLength = common.SafeIntToUint32(len(buf.Bytes())) dig, err := crypto.FastHash(buf.Bytes()) if err != nil { return 0, err @@ -1110,7 +1111,7 @@ func (m *GetSignaturesMessage) MarshalBinary() ([]byte, error) { h.Length = maxHeaderLength + uint32(len(body)) - 4 h.Magic = headerMagic h.ContentID = ContentIDGetSignatures - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1192,16 +1193,16 @@ type SignaturesMessage struct { // MarshalBinary encodes SignaturesMessage to binary form func (m *SignaturesMessage) MarshalBinary() ([]byte, error) { body := make([]byte, 4, 4+len(m.Signatures)) - binary.BigEndian.PutUint32(body[0:4], uint32(len(m.Signatures))) + binary.BigEndian.PutUint32(body[0:4], common.SafeIntToUint32(len(m.Signatures))) for _, b := range m.Signatures { body = append(body, b[:]...) } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDSignatures - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1285,10 +1286,10 @@ func (m *GetBlockMessage) MarshalBinary() ([]byte, error) { body := m.BlockID.Bytes() var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDGetBlock - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1329,10 +1330,11 @@ func parsePacket(data []byte, ContentID PeerMessageID, name string, f func(paylo if h.ContentID != ContentID { return fmt.Errorf("%s: wrong ContentID in Header: %x", name, h.ContentID) } - if len(data) < int(17+h.PayloadLength) { - return fmt.Errorf("%s: expected data at least %d, found %d", name, 17+h.PayloadLength, len(data)) + if len(data) < int(headerSizeWithPayload+h.payloadLength) { + return fmt.Errorf("%s: expected data at least %d, found %d", + name, headerSizeWithPayload+h.payloadLength, len(data)) } - err := f(data[17 : 17+h.PayloadLength]) + err := f(data[headerSizeWithPayload : headerSizeWithPayload+h.payloadLength]) if err != nil { return errors.Wrapf(err, "%s payload error", name) } @@ -1380,10 +1382,10 @@ type BlockMessage struct { // MarshalBinary encodes BlockMessage to binary form func (m *BlockMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.BlockBytes)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.BlockBytes)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDBlock - h.PayloadLength = uint32(len(m.BlockBytes)) + h.payloadLength = common.SafeIntToUint32(len(m.BlockBytes)) dig, err := crypto.FastHash(m.BlockBytes) if err != nil { return nil, err @@ -1400,10 +1402,10 @@ func (m *BlockMessage) MarshalBinary() ([]byte, error) { func MakeHeader(contentID PeerMessageID, payload []byte) (Header, error) { var h Header - h.Length = maxHeaderLength + uint32(len(payload)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(payload)) - headerChecksumLen h.Magic = headerMagic h.ContentID = contentID - h.PayloadLength = uint32(len(payload)) + h.payloadLength = common.SafeIntToUint32(len(payload)) dig, err := crypto.FastHash(payload) if err != nil { return Header{}, err @@ -1425,11 +1427,11 @@ func (m *BlockMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if uint32(len(data)) < 17+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("BlockMessage UnmarshalBinary: invalid data size") } - m.BlockBytes = make([]byte, h.PayloadLength) - copy(m.BlockBytes, data[17:17+h.PayloadLength]) + m.BlockBytes = make([]byte, h.payloadLength) + copy(m.BlockBytes, data[17:17+h.payloadLength]) return nil } @@ -1463,10 +1465,10 @@ type ScoreMessage struct { // MarshalBinary encodes ScoreMessage to binary form func (m *ScoreMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Score)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Score)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDScore - h.PayloadLength = uint32(len(m.Score)) + h.payloadLength = common.SafeIntToUint32(len(m.Score)) dig, err := crypto.FastHash(m.Score) if err != nil { return nil, err @@ -1494,11 +1496,11 @@ func (m *ScoreMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if uint32(len(data)) < 17+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("invalid data size") } - m.Score = make([]byte, h.PayloadLength) - copy(m.Score, data[17:17+h.PayloadLength]) + m.Score = make([]byte, h.payloadLength) + copy(m.Score, data[17:17+h.payloadLength]) return nil } @@ -1530,10 +1532,10 @@ type TransactionMessage struct { // MarshalBinary encodes TransactionMessage to binary form func (m *TransactionMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Transaction)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Transaction)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDTransaction - h.PayloadLength = uint32(len(m.Transaction)) + h.payloadLength = common.SafeIntToUint32(len(m.Transaction)) dig, err := crypto.FastHash(m.Transaction) if err != nil { return nil, err @@ -1558,11 +1560,11 @@ func (m *TransactionMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } // TODO check max length - if uint32(len(data)) < maxHeaderLength+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < maxHeaderLength+h.payloadLength { return errors.New("invalid data size") } - m.Transaction = make([]byte, h.PayloadLength) - copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Transaction = make([]byte, h.payloadLength) + copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) dig, err := crypto.FastHash(m.Transaction) if err != nil { return err @@ -1618,10 +1620,10 @@ func (m *CheckPointMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDCheckpoint - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1702,10 +1704,10 @@ type PBBlockMessage struct { // MarshalBinary encodes PBBlockMessage to binary form func (m *PBBlockMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.PBBlockBytes)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.PBBlockBytes)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPBBlock - h.PayloadLength = uint32(len(m.PBBlockBytes)) + h.payloadLength = common.SafeIntToUint32(len(m.PBBlockBytes)) dig, err := crypto.FastHash(m.PBBlockBytes) if err != nil { return nil, err @@ -1733,11 +1735,11 @@ func (m *PBBlockMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.PBBlockBytes = make([]byte, h.PayloadLength) - if uint32(len(data)) < 17+h.PayloadLength { + m.PBBlockBytes = make([]byte, h.payloadLength) + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("PBBlockMessage UnmarshalBinary: invalid data size") } - copy(m.PBBlockBytes, data[17:17+h.PayloadLength]) + copy(m.PBBlockBytes, data[17:17+h.payloadLength]) return nil } @@ -1771,10 +1773,10 @@ type PBTransactionMessage struct { // MarshalBinary encodes PBTransactionMessage to binary form func (m *PBTransactionMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Transaction)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Transaction)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPBTransaction - h.PayloadLength = uint32(len(m.Transaction)) + h.payloadLength = common.SafeIntToUint32(len(m.Transaction)) dig, err := crypto.FastHash(m.Transaction) if err != nil { return nil, err @@ -1799,11 +1801,11 @@ func (m *PBTransactionMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } // TODO check max length - m.Transaction = make([]byte, h.PayloadLength) - if uint32(len(data)) < maxHeaderLength+h.PayloadLength { + m.Transaction = make([]byte, h.payloadLength) + if common.SafeIntToUint32(len(data)) < maxHeaderLength+h.payloadLength { return errors.New("PBTransactionMessage UnmarshalBinary: invalid data size") } - copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) dig, err := crypto.FastHash(m.Transaction) if err != nil { return err @@ -1911,10 +1913,10 @@ func (m *GetBlockIdsMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDGetBlockIDs - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -2007,10 +2009,10 @@ func (m *BlockIdsMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDBlockIDs - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -2123,10 +2125,10 @@ type MiningLimits struct { func buildHeader(body []byte, messID PeerMessageID) (Header, error) { var h Header - h.Length = maxHeaderLength + uint32(len(body)) - headerChecksumLen + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = messID - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return Header{}, err @@ -2221,8 +2223,8 @@ func (m *BlockSnapshotMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDBlockSnapshot { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.Bytes = make([]byte, h.PayloadLength) - copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Bytes = make([]byte, h.payloadLength) + copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) return nil } @@ -2280,8 +2282,8 @@ func (m *MicroBlockSnapshotMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDMicroBlockSnapshot { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.Bytes = make([]byte, h.PayloadLength) - copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Bytes = make([]byte, h.payloadLength) + copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) return nil } diff --git a/pkg/ride/math/math_test.go b/pkg/ride/math/math_test.go index e85da6d19..c237b0c66 100644 --- a/pkg/ride/math/math_test.go +++ b/pkg/ride/math/math_test.go @@ -21,6 +21,7 @@ func TestFraction(t *testing.T) { }{ {-6, 6301369, 100, false, -378082}, {6, 6301369, 100, false, 378082}, + {4445280, 1, 1440, false, 3087}, {6, 6301369, 0, true, 0}, } { r, err := Fraction(tc.value, tc.numerator, tc.denominator) diff --git a/pkg/util/common/util.go b/pkg/util/common/util.go index 0aff46164..9d4e854ba 100644 --- a/pkg/util/common/util.go +++ b/pkg/util/common/util.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/mr-tron/base58/base58" "github.com/pkg/errors" "golang.org/x/exp/constraints" @@ -244,3 +245,11 @@ func padBytes(p byte, bytes []byte) []byte { copy(r[1:], bytes) return r } + +func SafeIntToUint32(v int) uint32 { + r, err := safecast.ToUint32(v) + if err != nil { + panic(err) + } + return r +}