diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..06b37cbbc2 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport and WebRTC addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) + require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String()) }), ) err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) @@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) require.Equal(t, h1.ID(), c.LocalPeer()) require.Equal(t, h2.ID(), c.RemotePeer()) })) @@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) { // In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections, // if the first connection attempt is rejected. connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses + // remove the certhash component from WebRTC and WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses + // remove the certhash component from WebRTC and WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }) } @@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) { gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr()) }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) @@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) { connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr()) + require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr()) require.Equal(t, h1.ID(), c.RemotePeer()) require.Equal(t, h2.ID(), c.LocalPeer()) }), diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..050bed4cbb 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -3,10 +3,17 @@ package transport_integration import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" + "math/big" "net" "runtime" "strings" @@ -30,8 +37,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" - tls "github.com/libp2p/go-libp2p/p2p/security/tls" + sectls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/websocket" "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" @@ -48,6 +56,7 @@ type TransportTestCaseOpts struct { NoRcmgr bool ConnGater connmgr.ConnectionGater ResourceManager network.ResourceManager + HostSeed string } func transformOpts(opts TransportTestCaseOpts) []config.Option { @@ -87,7 +96,7 @@ var transportsToTest = []TransportTestCase{ Name: "TCP / TLS / Yamux", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Security(sectls.ID, sectls.New)) libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) @@ -113,6 +122,26 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "Secure WebSocket with CA Certificate", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + wsOpts := []interface{}{websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})} + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + dnsName := fmt.Sprintf("example%s.com", opts.HostSeed) + cert, err := generateSelfSignedCert(dnsName) + require.NoError(t, err) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(fmt.Sprintf("/ip4/127.0.0.1/tcp/0/tls/sni/%s/ws", dnsName))) + wsOpts = append(wsOpts, websocket.WithTLSConfig(&tls.Config{Certificates: []tls.Certificate{cert}})) + } + libp2pOpts = append(libp2pOpts, libp2p.Transport(websocket.New, wsOpts...)) + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "QUIC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { @@ -158,6 +187,46 @@ var transportsToTest = []TransportTestCase{ }, } +func generateSelfSignedCert(dnsName string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{dnsName}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + privDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + privPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}) + + // Load the certificate and key into tls.Certificate + return tls.X509KeyPair(certPEM, privPEM) +} + func TestPing(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { @@ -798,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +func TestConnMatchingAddress(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client1.Close() + defer client2.Close() + + client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()}) + require.NoError(t, err) + + client1Conns := client1.Network().ConnsToPeer(server.ID()) + require.Equal(t, 1, len(client1Conns)) + remoteMA := client1Conns[0].RemoteMultiaddr() + + err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}}) + require.NoError(t, err) + }) + } +} diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index d4ba3c0550..8c20b4824f 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -264,14 +264,13 @@ func (l *listener) setupConnection( return nil, err } - localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) conn, err := newConnection( network.DirInbound, w.PeerConnection, l.transport, scope, l.transport.localPeerId, - localMultiaddrWithoutCerthash, + l.localMultiaddr, remotePeer, remotePubKey, remoteMultiaddr, diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index c4c16fd402..647f76ddc7 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, err } - remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) conn, err := newConnection( network.DirOutbound, @@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement localAddr, p, remotePubKey, - remoteMultiaddrWithoutCerthash, + remoteMultiaddr, w.IncomingDataChannels, w.PeerConnectionClosedCh, ) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..d6bf35d483 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,9 @@ package websocket import ( + "fmt" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "io" "net" "sync" @@ -25,10 +28,88 @@ type Conn struct { closeOnce sync.Once readLock, writeLock sync.Mutex + + laddr, raddr *Addr + laddrma, raddrma ma.Multiaddr } var _ net.Conn = (*Conn)(nil) +// NewConn creates a Conn given a regular gorilla/websocket Conn. +func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) { + laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure) + raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + + laddrma, err := manet.FromNetAddr(laddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + raddrma, err := manet.FromNetAddr(raddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + if secure { + if withoutWSS := raddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(raddrma) { + return nil, fmt.Errorf("missing wss component from converted multiaddr") + } else { + tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni)) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + raddrma = withoutWSS.Encapsulate(tlsSniWsMa) + } + } + + return &Conn{ + Conn: raw, + secure: secure, + DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, + laddrma: laddrma, + raddrma: raddrma, + }, nil +} + +func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) { + laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure) + raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + + laddrma, err := manet.FromNetAddr(laddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + raddrma, err := manet.FromNetAddr(raddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + if secure { + if withoutWSS := laddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(laddrma) { + return nil, fmt.Errorf("missing wss component from converted multiaddr") + } else { + tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni)) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + laddrma = withoutWSS.Encapsulate(tlsSniWsMa) + } + } + + return &Conn{ + Conn: raw, + secure: secure, + DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, + laddrma: laddrma, + raddrma: raddrma, + }, nil +} + // NewConn creates a Conn given a regular gorilla/websocket Conn. func NewConn(raw *ws.Conn, secure bool) *Conn { return &Conn{ @@ -122,11 +203,19 @@ func (c *Conn) Close() error { } func (c *Conn) LocalAddr() net.Addr { - return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure) + return c.laddr } func (c *Conn) RemoteAddr() net.Addr { - return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure) + return c.raddr +} + +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddrma +} + +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddrma } func (c *Conn) SetDeadline(t time.Time) error { diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 8071ddb814..6875ac1b2d 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + var sni string + if r.TLS != nil { + sni = r.TLS.ServerName + } + mnc, err := NewInboundConn(c, l.isWss, sni) + if err != nil { + _ = c.Close() + return + } + select { - case l.incoming <- NewConn(c, l.isWss): + case l.incoming <- mnc: case <-l.closed: - c.Close() + mnc.Close() } // The connection has been hijacked, it's safe to return. } @@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 36818decee..404a11e0eb 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -188,8 +188,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} + var sni string if isWss { - sni := "" sni, err = raddr.ValueForProtocol(ma.P_SNI) if err != nil { sni = "" @@ -220,7 +220,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon, isWss)) + mnc, err := NewOutboundConn(wscon, isWss, sni) if err != nil { wscon.Close() return nil, err diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..66704389b0 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" @@ -307,6 +308,84 @@ func TestWebsocketTransport(t *testing.T) { ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA) } +func TestWSSTransport(t *testing.T) { + peerA, ua := newUpgrader(t) + + const dnsName = "example.com" + // Generate the self-signed certificate and private key + certPEM, privPEM, err := generateSelfSignedCert(dnsName) + if err != nil { + t.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + // Load the certificate and key into tls.Certificate + cert, err := tls.X509KeyPair(certPEM, privPEM) + if err != nil { + t.Fatalf("Failed to load key pair: %v", err) + } + + // Create a TLS configuration with the certificate + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + ta, err := New(ua, nil, WithTLSConfig(tlsConfig)) + if err != nil { + t.Fatal(err) + } + + _, ub := newUpgrader(t) + cas := x509.NewCertPool() + cas.AppendCertsFromPEM(certPEM) + + tb, err := New(ub, nil, WithTLSClientConfig(&tls.Config{RootCAs: cas})) + if err != nil { + t.Fatal(err) + } + + // Note: the /wss form is not tested as it would require setting up custom DNS resolution + ttransport.SubtestTransport(t, ta, tb, fmt.Sprintf("/ip4/127.0.0.1/tcp/0/tls/sni/%s/ws", dnsName), peerA) +} + +func generateSelfSignedCert(dnsName string) ([]byte, []byte, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{dnsName}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + privDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, nil, err + } + privPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}) + + return certPEM, privPEM, nil +} + func isWSS(addr ma.Multiaddr) bool { if _, err := addr.ValueForProtocol(ma.P_WSS); err == nil { return true diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ff611fe927..84eb044f2c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -234,10 +234,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) { - local, err := toWebtransportMultiaddr(sess.LocalAddr()) - if err != nil { - return nil, fmt.Errorf("error determiniting local addr: %w", err) - } + local := l.Multiaddr() remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) if err != nil { return nil, fmt.Errorf("error determiniting remote addr: %w", err) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index acb40f0b89..17efa1bcf3 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -172,7 +172,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee if err != nil { return nil, err } - sconn, err := t.upgrade(ctx, sess, p, certHashes) + sconn, err := t.upgrade(ctx, sess, p, certHashes, raddr) if err != nil { sess.CloseWithError(1, "") return nil, err @@ -230,15 +230,11 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string return sess, conn, err } -func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash, remote ma.Multiaddr) (*connSecurityMultiaddrs, error) { local, err := toWebtransportMultiaddr(sess.LocalAddr()) if err != nil { return nil, fmt.Errorf("error determining local addr: %w", err) } - remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) - if err != nil { - return nil, fmt.Errorf("error determining remote addr: %w", err) - } str, err := sess.OpenStreamSync(ctx) if err != nil {