Skip to content

Commit

Permalink
Fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcamou committed Nov 29, 2024
1 parent c454cb9 commit 46a177c
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 137 deletions.
6 changes: 6 additions & 0 deletions node/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type NodeOption struct {
IsTelegramScraper bool
IsWebScraper bool

IsProxy bool

Bootnodes []string
RandomIdentity bool
Services []func(ctx context.Context, node *OracleNode)
Expand Down Expand Up @@ -87,6 +89,10 @@ var IsWebScraper = func(o *NodeOption) {
o.IsWebScraper = true
}

var IsProxy = func(o *NodeOption) {
o.IsProxy = true
}

func (a *NodeOption) Apply(opts ...Option) {
for _, opt := range opts {
opt(a)
Expand Down
6 changes: 5 additions & 1 deletion node/oracle_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ func (node *OracleNode) Start() (err error) {
go p(node.Context, node)
}

go myNetwork.Discover(node.Context, node.Options.Bootnodes, node.Host, node.DHT, node.Protocol)
protocols := []protocol.ID{node.Protocol}
if node.Options.IsProxy {
protocols = append(protocols, myNetwork.ProxyProtocol)
}
go myNetwork.Discover(node.Context, node.Options.Bootnodes, node.Host, node.DHT, protocols)

nodeData := node.NodeTracker.GetNodeData(node.Host.ID().String())
if nodeData == nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func InitOptions(cfg *AppConfig) ([]node.Option, *workers.WorkHandlerManager, *p
masaNodeOptions = append(masaNodeOptions, node.IsWebScraper)
}

if cfg.ProxyEnabled {
masaNodeOptions = append(masaNodeOptions, node.IsProxy)
}

workHandlerManager := workers.NewWorkHandlerManager(workerManagerOptions...)
blockChainEventTracker := node.NewBlockChain()
pubKeySub := &pubsub.PublicKeySubscriptionHandler{}
Expand Down
154 changes: 83 additions & 71 deletions pkg/network/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,34 @@ import (
// It initializes discovery via the DHT and advertises this node.
// It runs discovery in a loop with a ticker, re-advertising and finding new peers.
// For each discovered peer, it checks if already connected, and if not, dials them.
func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.IpfsDHT, protocol protocol.ID) {
func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.IpfsDHT, protocols []protocol.ID) {
var routingDiscovery *routing.RoutingDiscovery
protocolString := string(protocol)
logrus.Infof("[+] Discovering peers for protocol: %s", protocolString)

routingDiscovery = routing.NewRoutingDiscovery(dht)

// Advertise node right away, then it will re-advertise with each ticker interval
logrus.Infof("[+] Attempting to advertise protocol: %s", protocolString)
_, err := routingDiscovery.Advertise(ctx, protocolString)
if err != nil {
logrus.Debugf("[-] Failed to advertise protocol: %v", err)
} else {
logrus.Infof("[+] Successfully advertised protocol %s", protocolString)

protos := []string{}
for _, p := range protocols {
protos = append(protos, string(p))
}

for _, p := range protos {
logrus.Infof("[+] Discovering peers for protocol: %s", p)

routingDiscovery = routing.NewRoutingDiscovery(dht)

// Advertise node right away, then it will re-advertise with each ticker interval
logrus.Infof("[+] Attempting to advertise protocol: %s", p)
_, err := routingDiscovery.Advertise(ctx, p)
if err != nil {
logrus.Debugf("[-] Failed to advertise protocol: %v", err)
} else {
logrus.Infof("[+] Successfully advertised protocol %s", p)
}
}

ticker := time.NewTicker(time.Minute * 1)
defer ticker.Stop()

var peerChan <-chan peer.AddrInfo
var err error

for {
select {
Expand All @@ -56,66 +64,70 @@ func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.
logrus.Debug("[-] Searching for other peers...")
routingDiscovery = routing.NewRoutingDiscovery(dht)

// Advertise this node
logrus.Debugf("[-] Attempting to advertise protocol: %s", protocolString)
_, err := routingDiscovery.Advertise(ctx, protocolString)
if err != nil {
logrus.Debugf("[-] Failed to advertise protocol with error %v", err)

// Network retry when connectivity is temporarily lost using NewExponentialBackOff
expBackOff := backoff.NewExponentialBackOff()
expBackOff.MaxElapsedTime = time.Second * 10
err := backoff.Retry(func() error {
peerChan, err = routingDiscovery.FindPeers(ctx, protocolString)
return err
}, expBackOff)
if err != nil {
logrus.Warningf("[-] Retry failed to find peers: %v", err)
}

} else {
logrus.Infof("[+] Successfully advertised protocol: %s", protocolString)
}
for _, protocolString := range protos {
// Advertise this node
logrus.Debugf("[-] Attempting to advertise protocol: %s", protocolString)
if _, err := routingDiscovery.Advertise(ctx, protocolString); err != nil {
logrus.Debugf("[-] Failed to advertise protocol with error %v", err)

// Network retry when connectivity is temporarily lost using NewExponentialBackOff
expBackOff := backoff.NewExponentialBackOff()
expBackOff.MaxElapsedTime = time.Second * 10
err = backoff.Retry(func() error {
peerChan, err = routingDiscovery.FindPeers(ctx, protocolString)
return err
}, expBackOff)
if err != nil {
logrus.Warningf("[-] Retry failed to find peers: %v", err)
}

// Use the routing discovery to find peers.
peerChan, err = routingDiscovery.FindPeers(ctx, protocolString)
if err != nil {
logrus.Errorf("[-] Failed to find peers: %v", err)
} else {
logrus.Debug("[+] Successfully started finding peers")
}
select {
case availPeer, ok := <-peerChan:
if !ok {
logrus.Info("[+] Peer channel closed, restarting discovery")
break
}
// validating proper peers to connect to
availPeerAddrInfo := peer.AddrInfo{
ID: availPeer.ID,
Addrs: availPeer.Addrs,
} else {
logrus.Infof("[+] Successfully advertised protocol: %s", protocolString)
}
if availPeerAddrInfo.ID == host.ID() {
logrus.Debugf("Skipping connect to self: %s", availPeerAddrInfo.ID.String())
continue

// Use the routing discovery to find peers.
peerChan, err = routingDiscovery.FindPeers(ctx, protocolString)
if err != nil {
logrus.Errorf("[-] Failed to find peers: %v", err)
} else {
logrus.Debug("[+] Successfully started finding peers")
}
if len(availPeerAddrInfo.Addrs) == 0 {
for _, bn := range bootNodes {
bootNode := strings.Split(bn, "/")[len(strings.Split(bn, "/"))-1]
if availPeerAddrInfo.ID.String() != bootNode {
logrus.Warningf("Skipping connect to non bootnode peer with no multiaddress: %s", availPeerAddrInfo.ID.String())
continue

select {
case availPeer, ok := <-peerChan:
if !ok {
logrus.Info("[+] Peer channel closed, restarting discovery")
break
}
// validating proper peers to connect to
availPeerAddrInfo := peer.AddrInfo{
ID: availPeer.ID,
Addrs: availPeer.Addrs,
}
if availPeerAddrInfo.ID == host.ID() {
logrus.Debugf("Skipping connect to self: %s", availPeerAddrInfo.ID.String())
continue
}
if len(availPeerAddrInfo.Addrs) == 0 {
for _, bn := range bootNodes {
bootNode := strings.Split(bn, "/")[len(strings.Split(bn, "/"))-1]
if availPeerAddrInfo.ID.String() != bootNode {
logrus.Warningf("Skipping connect to non bootnode peer with no multiaddress: %s", availPeerAddrInfo.ID.String())
continue
}
}
}
}
logrus.Infof("[+] Available Peer: %s", availPeer.String())

if host.Network().Connectedness(availPeer.ID) != network.Connected {
if isConnectedToBootnode(host, bootNodes) {
_, err := host.Network().DialPeer(ctx, availPeer.ID)
if err != nil {
logrus.Warningf("[-] Failed to connect to peer %s, will retry...", availPeer.ID.String())
continue
logrus.Infof("[+] Available Peer: %s", availPeer.String())

if host.Network().Connectedness(availPeer.ID) != network.Connected {
if isConnectedToBootnode(host, bootNodes) {
_, err := host.Network().DialPeer(ctx, availPeer.ID)
if err != nil {
logrus.Warningf("[-] Failed to connect to peer %s, will retry...", availPeer.ID.String())
continue
} else {
logrus.Infof("[+] Connected to peer %s", availPeer.ID.String())
}
} else {
for _, bn := range bootNodes {
if len(bn) > 0 {
Expand All @@ -125,10 +137,10 @@ func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.
}
}
}
case <-ctx.Done():
logrus.Info("[-] Stopping peer discovery")
return
}
case <-ctx.Done():
logrus.Info("[-] Stopping peer discovery")
return
}
}
}
Expand Down
61 changes: 38 additions & 23 deletions pkg/network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package network

import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -39,7 +38,7 @@ type Proxy struct {
targetPort uint16
}

const proxyProtocol = "/connect-proxy/1.0.0"
const ProxyProtocol = "/masa/connect-proxy/0.0.1"

func NewProxy(host host.Host, listenAddr string, listenPort uint16, targetPort uint16) (*Proxy, error) {
addr := net.ParseIP(listenAddr)
Expand All @@ -52,42 +51,54 @@ func NewProxy(host host.Host, listenAddr string, listenPort uint16, targetPort u

// streamHandler handles the incoming libp2p streams. We know that the stream will contain an
// HTTP request, but strictly speaking we don't care (since CONNECT should act as a transparent
// tunnel).
// tunnel), so we just forward the data.
func (p *Proxy) streamHandler(stream network.Stream) {
logrus.Infof("Received new stream %s from peer %s", stream.ID(), stream.Conn().RemotePeer())
target := fmt.Sprintf("localhost:%d", p.targetPort)
conn, err := net.Dial("tcp", target)
if err != nil {
_ = stream.Reset()
logrus.Errorf("Error connecting to target host %s: %v", target, err)
return
}

go transfer(conn, stream)
go transfer(stream, conn)
}

// handleTunnel handles the HTTP CONNECT requests
func (p *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *http.Request) {
func (px *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
http.Error(w, "Proxy only supports CONNECT requests", http.StatusBadRequest)
logrus.Errorf("Received invalid request: %v", *req)
http.Error(w, "Proxy only supports CONNECT requests", http.StatusMethodNotAllowed)
logrus.Errorf("[-] Received invalid request: %#v", *req)
return
}

parts := strings.Split(req.URL.Host, ":")
peerID, err := peer.IDFromBytes([]byte(parts[0]))
logrus.Debugf("Received CONNECT request %#v", req)
parts := strings.Split(req.RequestURI, ":")
peerID, err := peer.Decode(parts[0])
if err != nil {
http.Error(w, fmt.Sprintf("Invalid peerID '%s'", parts[0]), http.StatusBadRequest)
logrus.Errorf("Invalid PeerID '%s' in host '%s'", parts[0], req.Host)
logrus.Errorf("[-] Invalid PeerID '%s' in host '%s'", parts[0], req.Host)
return
}

if peerID == px.host.ID() {
http.Error(w, fmt.Sprintf("Cannot establish tunnel to myself: %s", peerID), http.StatusBadRequest)
logrus.Errorf("[-] Tried to establish tunnel to myself")
return
}

destStream, err := p.host.NewStream(ctx, peerID, proxyProtocol)
logrus.Infof("Creating CONNECT tunnel from %s to peer %s", req.RemoteAddr, peerID)

destStream, err := px.host.NewStream(ctx, peerID, ProxyProtocol)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
logrus.Errorf("Error while creating stream to peer %s: %v", peerID, err)
return
}

logrus.Debug("Stream established, hijacking")
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
Expand All @@ -102,40 +113,44 @@ func (p *Proxy) handleTunnel(ctx context.Context, w http.ResponseWriter, req *ht
return
}

w.WriteHeader(http.StatusOK)
logrus.Debug("Sending response header")
hdr := fmt.Sprintf("%s 200 OK\n\n", req.Proto)
_, err = clientConn.Write([]byte(hdr))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
logrus.Errorf("Error while sending response header: %v", err)
return
}

go transfer(destStream, clientConn)
logrus.Debug("Starting transfer")
go transfer(clientConn, destStream)
go transfer(destStream, clientConn)
}

func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer closeStream(source)
defer closeStream(destination)
func transfer(dst io.WriteCloser, src io.ReadCloser) {
defer closeStream(src)
defer closeStream(dst)

if _, err := io.Copy(destination, source); err != nil {
if _, err := io.Copy(dst, src); err != nil {
logrus.Errorf("Error during transfer: %v", err)
}
}

func closeStream(s io.Closer) {
err := s.Close()
if err != nil {
if err := s.Close(); err != nil {
logrus.Errorf("Error closing stream: %v", err)
}
}

func (p *Proxy) Start(ctx context.Context) {
p.host.SetStreamHandler(proxyProtocol, p.streamHandler)
p.host.SetStreamHandler(ProxyProtocol, p.streamHandler)

server := &http.Server{
Addr: fmt.Sprintf("%s:%d", p.listenAddr, p.listenPort),
Handler: http.HandlerFunc(
func(w http.ResponseWriter, req *http.Request) {
go p.handleTunnel(ctx, w, req)
p.handleTunnel(ctx, w, req)
}),
// Disable HTTP/2.
// TODO Is this necessary, since we're not doing HTTPS?
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
}

if err := server.ListenAndServe(); err != nil {
Expand Down
Loading

0 comments on commit 46a177c

Please sign in to comment.