From 46a177c45102a1050000071c8c41448ef4b7e531 Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 28 Nov 2024 11:49:24 +0100 Subject: [PATCH] Fixes and tests --- node/options.go | 6 + node/oracle_node.go | 6 +- pkg/config/options.go | 4 + pkg/network/discover.go | 154 ++++++++++++---------- pkg/network/proxy.go | 61 +++++---- pkg/tests/integration/tracker_test.go | 180 ++++++++++++++++++++------ 6 files changed, 274 insertions(+), 137 deletions(-) diff --git a/node/options.go b/node/options.go index c6a6c826..63390b3d 100644 --- a/node/options.go +++ b/node/options.go @@ -22,6 +22,8 @@ type NodeOption struct { IsTelegramScraper bool IsWebScraper bool + IsProxy bool + Bootnodes []string RandomIdentity bool Services []func(ctx context.Context, node *OracleNode) @@ -87,6 +89,10 @@ var IsWebScraper = func(o *NodeOption) { o.IsWebScraper = true } +var IsProxy = func(o *NodeOption) { + o.IsProxy = true +} + func (a *NodeOption) Apply(opts ...Option) { for _, opt := range opts { opt(a) diff --git a/node/oracle_node.go b/node/oracle_node.go index 584a6eea..d6d11cbb 100644 --- a/node/oracle_node.go +++ b/node/oracle_node.go @@ -228,7 +228,11 @@ func (node *OracleNode) Start() (err error) { go p(node.Context, node) } - go myNetwork.Discover(node.Context, node.Options.Bootnodes, node.Host, node.DHT, node.Protocol) + protocols := []protocol.ID{node.Protocol} + if node.Options.IsProxy { + protocols = append(protocols, myNetwork.ProxyProtocol) + } + go myNetwork.Discover(node.Context, node.Options.Bootnodes, node.Host, node.DHT, protocols) nodeData := node.NodeTracker.GetNodeData(node.Host.ID().String()) if nodeData == nil { diff --git a/pkg/config/options.go b/pkg/config/options.go index f0e185da..2435b141 100644 --- a/pkg/config/options.go +++ b/pkg/config/options.go @@ -64,6 +64,10 @@ func InitOptions(cfg *AppConfig) ([]node.Option, *workers.WorkHandlerManager, *p masaNodeOptions = append(masaNodeOptions, node.IsWebScraper) } + if cfg.ProxyEnabled { + masaNodeOptions = append(masaNodeOptions, node.IsProxy) + } + workHandlerManager := workers.NewWorkHandlerManager(workerManagerOptions...) blockChainEventTracker := node.NewBlockChain() pubKeySub := &pubsub.PublicKeySubscriptionHandler{} diff --git a/pkg/network/discover.go b/pkg/network/discover.go index 18eb5e86..fbbe4271 100644 --- a/pkg/network/discover.go +++ b/pkg/network/discover.go @@ -22,26 +22,34 @@ import ( // It initializes discovery via the DHT and advertises this node. // It runs discovery in a loop with a ticker, re-advertising and finding new peers. // For each discovered peer, it checks if already connected, and if not, dials them. -func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.IpfsDHT, protocol protocol.ID) { +func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.IpfsDHT, protocols []protocol.ID) { var routingDiscovery *routing.RoutingDiscovery - protocolString := string(protocol) - logrus.Infof("[+] Discovering peers for protocol: %s", protocolString) - - routingDiscovery = routing.NewRoutingDiscovery(dht) - - // Advertise node right away, then it will re-advertise with each ticker interval - logrus.Infof("[+] Attempting to advertise protocol: %s", protocolString) - _, err := routingDiscovery.Advertise(ctx, protocolString) - if err != nil { - logrus.Debugf("[-] Failed to advertise protocol: %v", err) - } else { - logrus.Infof("[+] Successfully advertised protocol %s", protocolString) + + protos := []string{} + for _, p := range protocols { + protos = append(protos, string(p)) + } + + for _, p := range protos { + logrus.Infof("[+] Discovering peers for protocol: %s", p) + + routingDiscovery = routing.NewRoutingDiscovery(dht) + + // Advertise node right away, then it will re-advertise with each ticker interval + logrus.Infof("[+] Attempting to advertise protocol: %s", p) + _, err := routingDiscovery.Advertise(ctx, p) + if err != nil { + logrus.Debugf("[-] Failed to advertise protocol: %v", err) + } else { + logrus.Infof("[+] Successfully advertised protocol %s", p) + } } ticker := time.NewTicker(time.Minute * 1) defer ticker.Stop() var peerChan <-chan peer.AddrInfo + var err error for { select { @@ -56,66 +64,70 @@ func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht. logrus.Debug("[-] Searching for other peers...") routingDiscovery = routing.NewRoutingDiscovery(dht) - // Advertise this node - logrus.Debugf("[-] Attempting to advertise protocol: %s", protocolString) - _, err := routingDiscovery.Advertise(ctx, protocolString) - if err != nil { - logrus.Debugf("[-] Failed to advertise protocol with error %v", err) - - // Network retry when connectivity is temporarily lost using NewExponentialBackOff - expBackOff := backoff.NewExponentialBackOff() - expBackOff.MaxElapsedTime = time.Second * 10 - err := backoff.Retry(func() error { - peerChan, err = routingDiscovery.FindPeers(ctx, protocolString) - return err - }, expBackOff) - if err != nil { - logrus.Warningf("[-] Retry failed to find peers: %v", err) - } - - } else { - logrus.Infof("[+] Successfully advertised protocol: %s", protocolString) - } + for _, protocolString := range protos { + // Advertise this node + logrus.Debugf("[-] Attempting to advertise protocol: %s", protocolString) + if _, err := routingDiscovery.Advertise(ctx, protocolString); err != nil { + logrus.Debugf("[-] Failed to advertise protocol with error %v", err) + + // Network retry when connectivity is temporarily lost using NewExponentialBackOff + expBackOff := backoff.NewExponentialBackOff() + expBackOff.MaxElapsedTime = time.Second * 10 + err = backoff.Retry(func() error { + peerChan, err = routingDiscovery.FindPeers(ctx, protocolString) + return err + }, expBackOff) + if err != nil { + logrus.Warningf("[-] Retry failed to find peers: %v", err) + } - // Use the routing discovery to find peers. - peerChan, err = routingDiscovery.FindPeers(ctx, protocolString) - if err != nil { - logrus.Errorf("[-] Failed to find peers: %v", err) - } else { - logrus.Debug("[+] Successfully started finding peers") - } - select { - case availPeer, ok := <-peerChan: - if !ok { - logrus.Info("[+] Peer channel closed, restarting discovery") - break - } - // validating proper peers to connect to - availPeerAddrInfo := peer.AddrInfo{ - ID: availPeer.ID, - Addrs: availPeer.Addrs, + } else { + logrus.Infof("[+] Successfully advertised protocol: %s", protocolString) } - if availPeerAddrInfo.ID == host.ID() { - logrus.Debugf("Skipping connect to self: %s", availPeerAddrInfo.ID.String()) - continue + + // Use the routing discovery to find peers. + peerChan, err = routingDiscovery.FindPeers(ctx, protocolString) + if err != nil { + logrus.Errorf("[-] Failed to find peers: %v", err) + } else { + logrus.Debug("[+] Successfully started finding peers") } - if len(availPeerAddrInfo.Addrs) == 0 { - for _, bn := range bootNodes { - bootNode := strings.Split(bn, "/")[len(strings.Split(bn, "/"))-1] - if availPeerAddrInfo.ID.String() != bootNode { - logrus.Warningf("Skipping connect to non bootnode peer with no multiaddress: %s", availPeerAddrInfo.ID.String()) - continue + + select { + case availPeer, ok := <-peerChan: + if !ok { + logrus.Info("[+] Peer channel closed, restarting discovery") + break + } + // validating proper peers to connect to + availPeerAddrInfo := peer.AddrInfo{ + ID: availPeer.ID, + Addrs: availPeer.Addrs, + } + if availPeerAddrInfo.ID == host.ID() { + logrus.Debugf("Skipping connect to self: %s", availPeerAddrInfo.ID.String()) + continue + } + if len(availPeerAddrInfo.Addrs) == 0 { + for _, bn := range bootNodes { + bootNode := strings.Split(bn, "/")[len(strings.Split(bn, "/"))-1] + if availPeerAddrInfo.ID.String() != bootNode { + logrus.Warningf("Skipping connect to non bootnode peer with no multiaddress: %s", availPeerAddrInfo.ID.String()) + continue + } } } - } - logrus.Infof("[+] Available Peer: %s", availPeer.String()) - - if host.Network().Connectedness(availPeer.ID) != network.Connected { - if isConnectedToBootnode(host, bootNodes) { - _, err := host.Network().DialPeer(ctx, availPeer.ID) - if err != nil { - logrus.Warningf("[-] Failed to connect to peer %s, will retry...", availPeer.ID.String()) - continue + logrus.Infof("[+] Available Peer: %s", availPeer.String()) + + if host.Network().Connectedness(availPeer.ID) != network.Connected { + if isConnectedToBootnode(host, bootNodes) { + _, err := host.Network().DialPeer(ctx, availPeer.ID) + if err != nil { + logrus.Warningf("[-] Failed to connect to peer %s, will retry...", availPeer.ID.String()) + continue + } else { + logrus.Infof("[+] Connected to peer %s", availPeer.ID.String()) + } } else { for _, bn := range bootNodes { if len(bn) > 0 { @@ -125,10 +137,10 @@ func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht. } } } + case <-ctx.Done(): + logrus.Info("[-] Stopping peer discovery") + return } - case <-ctx.Done(): - logrus.Info("[-] Stopping peer discovery") - return } } } diff --git a/pkg/network/proxy.go b/pkg/network/proxy.go index 2ac2ead9..9b82a636 100644 --- a/pkg/network/proxy.go +++ b/pkg/network/proxy.go @@ -2,7 +2,6 @@ package network import ( "context" - "crypto/tls" "fmt" "io" "net" @@ -39,7 +38,7 @@ type Proxy struct { targetPort uint16 } -const proxyProtocol = "/connect-proxy/1.0.0" +const ProxyProtocol = "/masa/connect-proxy/0.0.1" func NewProxy(host host.Host, listenAddr string, listenPort uint16, targetPort uint16) (*Proxy, error) { addr := net.ParseIP(listenAddr) @@ -52,13 +51,15 @@ func NewProxy(host host.Host, listenAddr string, listenPort uint16, targetPort u // streamHandler handles the incoming libp2p streams. We know that the stream will contain an // HTTP request, but strictly speaking we don't care (since CONNECT should act as a transparent -// tunnel). +// tunnel), so we just forward the data. func (p *Proxy) streamHandler(stream network.Stream) { + logrus.Infof("Received new stream %s from peer %s", stream.ID(), stream.Conn().RemotePeer()) target := fmt.Sprintf("localhost:%d", p.targetPort) conn, err := net.Dial("tcp", target) if err != nil { _ = stream.Reset() logrus.Errorf("Error connecting to target host %s: %v", target, err) + return } go transfer(conn, stream) @@ -66,28 +67,38 @@ func (p *Proxy) streamHandler(stream network.Stream) { } // handleTunnel handles the HTTP CONNECT requests -func (p *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *http.Request) { +func (px *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodConnect { - http.Error(w, "Proxy only supports CONNECT requests", http.StatusBadRequest) - logrus.Errorf("Received invalid request: %v", *req) + http.Error(w, "Proxy only supports CONNECT requests", http.StatusMethodNotAllowed) + logrus.Errorf("[-] Received invalid request: %#v", *req) return } - parts := strings.Split(req.URL.Host, ":") - peerID, err := peer.IDFromBytes([]byte(parts[0])) + logrus.Debugf("Received CONNECT request %#v", req) + parts := strings.Split(req.RequestURI, ":") + peerID, err := peer.Decode(parts[0]) if err != nil { http.Error(w, fmt.Sprintf("Invalid peerID '%s'", parts[0]), http.StatusBadRequest) - logrus.Errorf("Invalid PeerID '%s' in host '%s'", parts[0], req.Host) + logrus.Errorf("[-] Invalid PeerID '%s' in host '%s'", parts[0], req.Host) + return + } + + if peerID == px.host.ID() { + http.Error(w, fmt.Sprintf("Cannot establish tunnel to myself: %s", peerID), http.StatusBadRequest) + logrus.Errorf("[-] Tried to establish tunnel to myself") return } - destStream, err := p.host.NewStream(ctx, peerID, proxyProtocol) + logrus.Infof("Creating CONNECT tunnel from %s to peer %s", req.RemoteAddr, peerID) + + destStream, err := px.host.NewStream(ctx, peerID, ProxyProtocol) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) logrus.Errorf("Error while creating stream to peer %s: %v", peerID, err) return } + logrus.Debug("Stream established, hijacking") hijacker, ok := w.(http.Hijacker) if !ok { http.Error(w, "Hijacking not supported", http.StatusInternalServerError) @@ -102,40 +113,44 @@ func (p *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *ht return } - w.WriteHeader(http.StatusOK) + logrus.Debug("Sending response header") + hdr := fmt.Sprintf("%s 200 OK\n\n", req.Proto) + _, err = clientConn.Write([]byte(hdr)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + logrus.Errorf("Error while sending response header: %v", err) + return + } - go transfer(destStream, clientConn) + logrus.Debug("Starting transfer") go transfer(clientConn, destStream) + go transfer(destStream, clientConn) } -func transfer(destination io.WriteCloser, source io.ReadCloser) { - defer closeStream(source) - defer closeStream(destination) +func transfer(dst io.WriteCloser, src io.ReadCloser) { + defer closeStream(src) + defer closeStream(dst) - if _, err := io.Copy(destination, source); err != nil { + if _, err := io.Copy(dst, src); err != nil { logrus.Errorf("Error during transfer: %v", err) } } func closeStream(s io.Closer) { - err := s.Close() - if err != nil { + if err := s.Close(); err != nil { logrus.Errorf("Error closing stream: %v", err) } } func (p *Proxy) Start(ctx context.Context) { - p.host.SetStreamHandler(proxyProtocol, p.streamHandler) + p.host.SetStreamHandler(ProxyProtocol, p.streamHandler) server := &http.Server{ Addr: fmt.Sprintf("%s:%d", p.listenAddr, p.listenPort), Handler: http.HandlerFunc( func(w http.ResponseWriter, req *http.Request) { - go p.handleTunnel(ctx, w, req) + p.handleTunnel(ctx, w, req) }), - // Disable HTTP/2. - // TODO Is this necessary, since we're not doing HTTPS? - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } if err := server.ListenAndServe(); err != nil { diff --git a/pkg/tests/integration/tracker_test.go b/pkg/tests/integration/tracker_test.go index 40a21344..f03666e8 100644 --- a/pkg/tests/integration/tracker_test.go +++ b/pkg/tests/integration/tracker_test.go @@ -2,15 +2,70 @@ package masa_test import ( + "bufio" "context" "fmt" + "net" + "net/http" + "net/url" . "github.com/masa-finance/masa-oracle/node" "github.com/masa-finance/masa-oracle/pkg/config" + "github.com/masa-finance/masa-oracle/pkg/network" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +// setupNodes sets up two Masa nodes and connects them together +// addOpts are additional options to add to the node for the specific test +func setupNodes(ctx context.Context, addOpts ...Option) (*OracleNode, *OracleNode) { + opts := []Option{EnableStaked, EnableRandomIdentity} + opts = append(opts, addOpts...) + + n, err := NewOracleNode( + ctx, + config.WithConstantOptions(opts...)..., + ) + Expect(err).ToNot(HaveOccurred()) + + err = n.Start() + Expect(err).ToNot(HaveOccurred()) + + addrs, err := n.GetP2PMultiAddrs() + Expect(err).ToNot(HaveOccurred()) + + var bootNodes []string + for _, addr := range addrs { + bootNodes = append(bootNodes, addr.String()) + } + + By(fmt.Sprintf("Generating second node with bootnodes %+v", bootNodes)) + opts = append(opts, WithBootNodes(bootNodes...)) + + n2, err := NewOracleNode( + ctx, + config.WithConstantOptions(opts...)..., + ) + Expect(err).ToNot(HaveOccurred()) + Expect(n.Host.ID()).ToNot(Equal(n2.Host.ID())) + + err = n2.Start() + Expect(err).ToNot(HaveOccurred()) + + // Wait for the nodes to see each other in their respective nodeTracker + Eventually(func() bool { + datas := n2.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + Eventually(func() bool { + datas := n.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + return n, n2 +} + var _ = Describe("Oracle integration tests", func() { Context("NodeData distribution", func() { It("is distributed across two nodes", func() { @@ -18,62 +73,103 @@ var _ = Describe("Oracle integration tests", func() { ctx, cancel := context.WithCancel(ctx) defer cancel() - n, err := NewOracleNode( - ctx, - config.WithConstantOptions( - EnableStaked, - EnableRandomIdentity, - )..., - ) + n, n2 := setupNodes(ctx) + + data := n.NodeTracker.GetAllNodeData() + + peerIds := []string{} + for _, d := range data { + peerIds = append(peerIds, d.PeerId.String()) + } + + Expect(peerIds).To(ContainElement(n.Host.ID().String())) + Expect(peerIds).To(ContainElement(n2.Host.ID().String())) + }) + }) + + Context("CONNECT proxy", func() { + It("tunnels the connection", func() { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + data := []byte("Time is an illusion. Lunchtime, doubly so.\n") + + // Simple echo server + lis, err := net.Listen("tcp", "127.0.0.1:14242") Expect(err).ToNot(HaveOccurred()) - err = n.Start() + go func(l net.Listener) { + conn, err := l.Accept() + Expect(err).ToNot(HaveOccurred()) + + buf := make([]byte, len(data)) + _, err = conn.Read(buf) + Expect(err).ToNot(HaveOccurred()) + Expect(buf).To(Equal(data)) + + _, err = conn.Write(buf) + Expect(err).ToNot(HaveOccurred()) + + err = conn.Close() + Expect(err).ToNot(HaveOccurred()) + }(lis) + + n, n2 := setupNodes(ctx, IsProxy) + + // Create and start the proxies + p, err := network.NewProxy(n.Host, "127.0.0.1", 24242, 14242) Expect(err).ToNot(HaveOccurred()) + go func(ctx context.Context) { + p.Start(ctx) + }(ctx) - addrs, err := n.GetP2PMultiAddrs() + p2, err := network.NewProxy(n2.Host, "127.0.0.1", 34242, 14242) Expect(err).ToNot(HaveOccurred()) + go func(ctx context.Context) { + p2.Start(ctx) + }(ctx) - var bootNodes []string - for _, addr := range addrs { - bootNodes = append(bootNodes, addr.String()) + // Establish the proxy connection + target := fmt.Sprintf("%s:0", n2.Host.ID()) + // This is ridiculous but it seems that Go's http.Request makes assumptions that don't work with CONNECT + // BUT we need the req to properly read the response ¯\_(ツ)_/¯ + rawReq := fmt.Sprintf("CONNECT %s HTTP/1.1\n\n", target) + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Host: target}, + Header: make(http.Header), } - By(fmt.Sprintf("Generating second node with bootnodes %+v", bootNodes)) - n2, err := NewOracleNode( - ctx, - config.WithConstantOptions( - EnableStaked, - WithBootNodes(bootNodes...), - EnableRandomIdentity, - )..., - ) - Expect(err).ToNot(HaveOccurred()) - Expect(n.Host.ID()).ToNot(Equal(n2.Host.ID())) + // Wait until the proxy is listening + var conn net.Conn + for { + conn, err = net.Dial("tcp", "127.0.0.1:24242") + if err == nil { + break + } else { + Expect(err.Error()).To(ContainSubstring("connection refused")) + } + } - err = n2.Start() + // Send the CONNECT request, wait for the 200 to indicate that the tunnel is established + _, err = conn.Write([]byte(rawReq)) Expect(err).ToNot(HaveOccurred()) - // Wait for the nodes to see each others in their respective - // nodeTracker - Eventually(func() bool { - datas := n2.NodeTracker.GetAllNodeData() - return len(datas) == 2 - }, "30s").Should(BeTrue()) - - Eventually(func() bool { - datas := n.NodeTracker.GetAllNodeData() - return len(datas) == 2 - }, "30s").Should(BeTrue()) + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) - data := n.NodeTracker.GetAllNodeData() + // Now send the data and wait for it to come back + _, err = conn.Write(data) + Expect(err).ToNot(HaveOccurred()) - peerIds := []string{} - for _, d := range data { - peerIds = append(peerIds, d.PeerId.String()) - } + buf := make([]byte, len(data)) + _, err = resp.Body.Read(buf) + Expect(err).ToNot(HaveOccurred()) + Expect(buf).To(Equal(data)) - Expect(peerIds).To(ContainElement(n.Host.ID().String())) - Expect(peerIds).To(ContainElement(n2.Host.ID().String())) + resp.Body.Close() }) }) })