Skip to content

Commit

Permalink
feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS…
Browse files Browse the repository at this point in the history
…, and WSS transports
  • Loading branch information
aschmahmann committed Sep 27, 2024
1 parent 9038a72 commit b6f1fb0
Show file tree
Hide file tree
Showing 10 changed files with 635 additions and 24 deletions.
32 changes: 29 additions & 3 deletions p2p/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"

logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
Expand All @@ -33,6 +34,9 @@ type canKeepAlive interface {

var _ canKeepAlive = &net.TCPConn{}

// Deprecated: Use tcpreuse.ReuseportIsAvailable
var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable

func tryKeepAlive(conn net.Conn, keepAlive bool) {
keepAliveConn, ok := conn.(canKeepAlive)
if !ok {
Expand Down Expand Up @@ -113,6 +117,13 @@ func WithMetrics() Option {
}
}

func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option {
return func(tr *TcpTransport) error {
tr.sharedTcp = mgr
return nil
}
}

// TcpTransport is the TCP transport.
type TcpTransport struct {
// Connection upgrader for upgrading insecure stream connections to
Expand All @@ -122,6 +133,9 @@ type TcpTransport struct {
disableReuseport bool // Explicitly disable reuseport.
enableMetrics bool

// share and demultiplex TCP listeners across multiple transports
sharedTcp *tcpreuse.ConnMgr

// TCP connect timeout
connectTimeout time.Duration

Expand Down Expand Up @@ -168,6 +182,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
defer cancel()
}

if t.sharedTcp != nil {
return t.sharedTcp.DialContext(ctx, raddr)
}

if t.UseReuseport() {
return t.reuse.DialContext(ctx, raddr)
}
Expand Down Expand Up @@ -233,10 +251,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p

// UseReuseport returns true if reuseport is enabled and available.
func (t *TcpTransport) UseReuseport() bool {
return !t.disableReuseport && ReuseportIsAvailable()
return !t.disableReuseport && tcpreuse.ReuseportIsAvailable()
}

func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) {
if t.UseReuseport() {
return t.reuse.Listen(laddr)
}
Expand All @@ -245,10 +263,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {

// Listen listens on the given multiaddr.
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
list, err := t.maListen(laddr)
var list manet.Listener
var err error

if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect)
}
if err != nil {
return nil, err
}

if t.enableMetrics {
list = newTracingListener(&tcpListener{list, 0})
}
Expand Down
13 changes: 7 additions & 6 deletions p2p/transport/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -41,9 +42,9 @@ func TestTcpTransport(t *testing.T) {
zero := "/ip4/127.0.0.1/tcp/0"
ttransport.SubtestTransport(t, ta, tb, zero, peerA)

envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}

func TestTcpTransportWithMetrics(t *testing.T) {
Expand Down Expand Up @@ -126,9 +127,9 @@ func TestTcpTransportCantDialDNS(t *testing.T) {
t.Fatal("shouldn't be able to dial dns")
}

envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}

func TestTcpTransportCantListenUtp(t *testing.T) {
Expand All @@ -143,9 +144,9 @@ func TestTcpTransportCantListenUtp(t *testing.T) {
_, err = tpt.Listen(utpa)
require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport")

envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}

func TestDialWithUpdates(t *testing.T) {
Expand Down
240 changes: 240 additions & 0 deletions p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package tcpreuse

import (
"bufio"
"errors"
"fmt"
"io"
"math"
"net"
"time"

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)

type peekAble interface {
// Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If Peek returns fewer than n bytes, it
// also returns an error explaining why the read is short. The error is
// [ErrBufferFull] if n is larger than b's buffer size.
Peek(n int) ([]byte, error)
}

var _ peekAble = (*bufio.Reader)(nil)

type DemultiplexedConnType int

const (
Unknown DemultiplexedConnType = iota
MultistreamSelect
HTTP
TLS
)

func (t DemultiplexedConnType) String() string {
switch t {
case MultistreamSelect:
return "MultistreamSelect"
case HTTP:
return "HTTP"
case TLS:
return "TLS"
default:
return fmt.Sprintf("Unknown(%d)", int(t))
}
}

func (t DemultiplexedConnType) IsKnown() bool {
return t >= 1 || t <= 3
}

func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}

s, sc, err := ReadSampleFromConn(c)
if err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}

if err := c.SetReadDeadline(time.Time{}); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}

if IsMultistreamSelect(s) {
return MultistreamSelect, sc, nil
}
if IsTLS(s) {
return TLS, sc, nil
}
if IsHTTP(s) {
return HTTP, sc, nil
}
return Unknown, sc, nil
}

// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged.
// If an error occurs it only return the error.
func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
if peekAble, ok := c.(peekAble); ok {
b, err := peekAble.Peek(len(Sample{}))
switch {
case err == nil:
mac, err := manet.WrapNetConn(c)
if err != nil {
return Sample{}, nil, err
}

return Sample(b), mac, nil
case errors.Is(err, bufio.ErrBufferFull):
// fallback to sampledConn
default:
return Sample{}, nil, err
}
}

tcpConnLike, ok := c.(tcpConnInterface)
if !ok {
return Sample{}, nil, fmt.Errorf("expected tcp-like connection")
}

laddr, err := manet.FromNetAddr(c.LocalAddr())
if err != nil {
return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err)
}

raddr, err := manet.FromNetAddr(c.RemoteAddr())
if err != nil {
return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
}

sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}}
_, err = io.ReadFull(c, sc.s[:])
if err != nil {
return Sample{}, nil, err
}

return sc.s, sc, nil
}

// Try out best to mimic a TCPConn's functions
// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection
// If this is an issue here we can revisit the options.
type tcpConnInterface interface {
net.Conn

CloseRead() error
CloseWrite() error

SetLinger(sec int) error
SetKeepAlive(keepalive bool) error
SetKeepAlivePeriod(d time.Duration) error
SetNoDelay(noDelay bool) error
MultipathTCP() (bool, error)

io.ReaderFrom
io.WriterTo
}

type maEndpoints struct {
laddr ma.Multiaddr
raddr ma.Multiaddr
}

// LocalMultiaddr returns the local address associated with
// this connection
func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
return c.laddr
}

// RemoteMultiaddr returns the remote address associated with
// this connection
func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
return c.raddr
}

type sampledConn struct {
tcpConnInterface
maEndpoints

s Sample
readFromSample uint8
}

var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow
var _ io.ReaderFrom = (*sampledConn)(nil)
var _ io.WriterTo = (*sampledConn)(nil)

func (sc *sampledConn) Read(b []byte) (int, error) {
if int(sc.readFromSample) != len(sc.s) {
red := copy(b, sc.s[sc.readFromSample:])
sc.readFromSample += uint8(red)
return red, nil
}

return sc.tcpConnInterface.Read(b)
}

// forward optimizations
func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(sc.tcpConnInterface, r)
}

// forward optimizations
func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) {
if int(sc.readFromSample) != len(sc.s) {
b := sc.s[sc.readFromSample:]
written, err := w.Write(b)
if written < 0 || len(b) < written {
// buggy writer, harden against this
sc.readFromSample = uint8(len(sc.s))
total = int64(len(sc.s))
} else {
sc.readFromSample += uint8(written)
total += int64(written)
}
if err != nil {
return total, err
}
}

written, err := io.Copy(w, sc.tcpConnInterface)
total += written
return total, err
}

type Matcher interface {
Match(s Sample) bool
}

// Sample might evolve over time.
type Sample [3]byte

// Matchers are implemented here instead of in the transports so we can easily fuzz them together.

func IsMultistreamSelect(s Sample) bool {
return string(s[:]) == "\x13/m"
}

func IsHTTP(s Sample) bool {
switch string(s[:]) {
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
return true
default:
return false
}
}

func IsTLS(s Sample) bool {
switch string(s[:]) {
case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04":
return true
default:
return false
}
}
Loading

0 comments on commit b6f1fb0

Please sign in to comment.