diff --git a/go.mod b/go.mod index c6f3f2b324..38caff5d85 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/mr-tron/base58 v1.2.0 github.com/multiformats/go-base32 v0.1.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 github.com/multiformats/go-multiaddr-dns v0.4.0 github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multibase v0.2.0 diff --git a/go.sum b/go.sum index 4650eef1f0..082c64848f 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= -github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= -github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 60f8ca0c06..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" @@ -100,27 +101,9 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "TCP-Shared / TLS / Yamux", - HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { - libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) - libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) - if opts.NoListen { - libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) - } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) - } - h, err := libp2p.New(libp2pOpts...) - require.NoError(t, err) - return h - }, - }, - { - Name: "WebSocket-Shared", + Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { @@ -132,13 +115,13 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebSocket", + Name: "QUIC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -146,13 +129,13 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "QUIC", + Name: "WebTransport", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -160,13 +143,14 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebTransport", + Name: "WebRTC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pwebrtc.New)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/webrtc-direct")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -174,14 +158,14 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebRTC", + Name: "Memory", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pwebrtc.New)) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/webrtc-direct")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go new file mode 100644 index 0000000000..61c81462cc --- /dev/null +++ b/p2p/transport/memory/conn.go @@ -0,0 +1,143 @@ +package memory + +import ( + "context" + "sync" + "sync/atomic" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + id int64 + rconn *conn + + scope network.ConnManagementScope + listener *listener + transport *transport + + localPeer peer.ID + localMultiaddr ma.Multiaddr + + remotePeerID peer.ID + remotePubKey ic.PubKey + remoteMultiaddr ma.Multiaddr + + mu sync.Mutex + + closed atomic.Bool + closeOnce sync.Once + + streamC chan *stream + streams map[int64]network.MuxedStream +} + +var _ tpt.CapableConn = &conn{} + +func newConnection( + t *transport, + s *stream, + localPeer peer.ID, + localMultiaddr ma.Multiaddr, + remotePubKey ic.PubKey, + remotePeer peer.ID, + remoteMultiaddr ma.Multiaddr, +) *conn { + c := &conn{ + id: connCounter.Add(1), + transport: t, + localPeer: localPeer, + localMultiaddr: localMultiaddr, + remotePubKey: remotePubKey, + remotePeerID: remotePeer, + remoteMultiaddr: remoteMultiaddr, + streamC: make(chan *stream, 1), + streams: make(map[int64]network.MuxedStream), + } + + c.addStream(s.id, s) + return c +} + +func (c *conn) Close() error { + c.closeOnce.Do(func() { + c.closed.Store(true) + go c.rconn.Close() + c.teardown() + }) + + return nil +} + +func (c *conn) IsClosed() bool { + return c.closed.Load() +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + sl, sr := newStreamPair() + + c.streamC <- sr + sl.conn = c + c.addStream(sl.id, sl) + return sl, nil +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + in := <-c.streamC + in.conn = c + c.addStream(in.id, in) + return in, nil +} + +func (c *conn) LocalPeer() peer.ID { return c.localPeer } + +// RemotePeer returns the peer ID of the remote peer. +func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } + +// RemotePublicKey returns the public pkey of the remote peer. +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } + +// LocalMultiaddr returns the local Multiaddr associated +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } + +// RemoteMultiaddr returns the remote Multiaddr associated +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } + +func (c *conn) Transport() tpt.Transport { + return c.transport +} + +func (c *conn) Scope() network.ConnScope { + return c.scope +} + +// ConnState is the state of security connection. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{Transport: "memory"} +} + +func (c *conn) addStream(id int64, stream network.MuxedStream) { + c.mu.Lock() + defer c.mu.Unlock() + + c.streams[id] = stream +} + +func (c *conn) removeStream(id int64) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.streams, id) +} + +func (c *conn) teardown() { + for _, s := range c.streams { + s.Reset() + } + + // TODO: remove self from listener +} diff --git a/p2p/transport/memory/hub.go b/p2p/transport/memory/hub.go new file mode 100644 index 0000000000..55b85ccbad --- /dev/null +++ b/p2p/transport/memory/hub.go @@ -0,0 +1,80 @@ +package memory + +import ( + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + "sync" + "sync/atomic" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) + memhub = newHub() +) + +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func newHub() *hub { + return &hub{ + pubKeys: make(map[peer.ID]ic.PubKey), + listeners: make(map[string]*listener), + } +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go new file mode 100644 index 0000000000..81d5e484d0 --- /dev/null +++ b/p2p/transport/memory/listener.go @@ -0,0 +1,75 @@ +package memory + +import ( + "context" + "net" + "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + listenerQueueSize = 16 +) + +type listener struct { + id int64 + + t *transport + ctx context.Context + cancel context.CancelFunc + laddr ma.Multiaddr + + mu sync.Mutex + connCh chan *conn + connections map[int64]*conn +} + +func (l *listener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +func newListener(t *transport, laddr ma.Multiaddr) *listener { + ctx, cancel := context.WithCancel(context.Background()) + return &listener{ + id: listenerCounter.Add(1), + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), + connections: make(map[int64]*conn), + } +} + +// Accept accepts new connections. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + case c, ok := <-l.connCh: + if !ok { + return nil, tpt.ErrListenerClosed + } + + l.mu.Lock() + defer l.mu.Unlock() + + c.listener = l + c.transport = l.t + l.connections[c.id] = c + return c, nil + } +} + +// Close closes the listener. +func (l *listener) Close() error { + l.cancel() + return nil +} + +// Addr returns the address of this listener. +func (l *listener) Addr() net.Addr { + return nil +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go new file mode 100644 index 0000000000..d196aea0fc --- /dev/null +++ b/p2p/transport/memory/stream.go @@ -0,0 +1,202 @@ +package memory + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p/core/network" +) + +// stream implements network.Stream +type stream struct { + id int64 + conn *conn + + wrMu sync.Mutex // Serialize Write operations + buf *bytes.Buffer // Buffer for partial reads + + // Used by local Read to interact with remote Write. + rdRx <-chan []byte + + // Used by local Write to interact with remote Read. + wrTx chan<- []byte + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + reset chan struct{} + close chan struct{} + readClosed atomic.Bool + writeClosed atomic.Bool +} + +var ErrClosed = errors.New("stream closed") + +func newStreamPair() (*stream, *stream) { + cb1 := make(chan []byte, 1) + cb2 := make(chan []byte, 1) + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + + sa := newStream(cb1, cb2, done1, done2) + sb := newStream(cb2, cb1, done2, done1) + + return sa, sb +} + +func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream { + s := &stream{ + rdRx: rdRx, + wrTx: wrTx, + buf: new(bytes.Buffer), + localDone: localDone, + remoteDone: remoteDone, + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), + } + + return s +} + +func (p *stream) Write(b []byte) (int, error) { + if p.writeClosed.Load() { + return 0, ErrClosed + } + + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &net.OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err +} + +func (p *stream) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() + + select { + case <-p.close: + return n, ErrClosed + case <-p.reset: + return n, network.ErrReset + case p.wrTx <- b: + n += len(b) + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe + } + + return n, nil +} + +func (p *stream) Read(b []byte) (int, error) { + if p.readClosed.Load() { + return 0, ErrClosed + } + + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + + return n, err +} + +func (p *stream) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + } + + select { + case <-p.reset: + return n, network.ErrReset + case bw, ok := <-p.rdRx: + if !ok { + p.readClosed.Store(true) + return 0, io.EOF + } + + p.buf.Write(bw) + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + default: + n, err = p.buf.Read(b) + } + + return n, err +} + +func (s *stream) CloseWrite() error { + select { + case s.close <- struct{}{}: + default: + } + + s.writeClosed.Store(true) + return nil +} + +func (s *stream) CloseRead() error { + s.readClosed.Store(true) + return nil +} + +func (s *stream) Close() error { + _ = s.CloseRead() + return s.CloseWrite() +} + +func (s *stream) Reset() error { + select { + case s.reset <- struct{}{}: + default: + } + + s.once.Do(func() { + close(s.localDone) + }) + + // No meaningful error case here. + return nil +} + +func (s *stream) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..f154bbbcdd --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,67 @@ +package memory + +import ( + "github.com/stretchr/testify/require" + "io" + "testing" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + t.Parallel() + streamLocal, streamRemote := newStreamPair() + + // send a foobar from the client + n, err := streamLocal.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, streamLocal.CloseWrite()) + + // writing after closing should error + _, err = streamLocal.Write([]byte("foobar")) + require.Error(t, err) + + // now read all the data on the server side + b, err := io.ReadAll(streamRemote) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + + // reading again should give another io.EOF + n, err = streamRemote.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + + // send something back + _, err = streamRemote.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, streamRemote.CloseWrite()) + + // and read it at the client + b, err = io.ReadAll(streamLocal) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + require.NoError(t, streamLocal.Close()) + require.NoError(t, streamRemote.Close()) +} + +func TestStreamPartialReads(t *testing.T) { + t.Parallel() + streamLocal, streamRemote := newStreamPair() + + _, err := streamRemote.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, streamRemote.CloseWrite()) + + n, err := streamLocal.Read([]byte{}) // empty read + require.NoError(t, err) + require.Zero(t, n) + b := make([]byte, 3) + n, err = streamLocal.Read(b) + require.Equal(t, 3, n) + require.NoError(t, err) + require.Equal(t, []byte("foo"), b) + b, err = io.ReadAll(streamLocal) + require.NoError(t, err) + require.Equal(t, []byte("bar"), b) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go new file mode 100644 index 0000000000..39d152e7d3 --- /dev/null +++ b/p2p/transport/memory/transport.go @@ -0,0 +1,144 @@ +package memory + +import ( + "context" + "errors" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "sync" +) + +type transport struct { + psk pnet.PSK + rcmgr network.ResourceManager + localPeerID peer.ID + localPrivKey ic.PrivKey + localPubKey ic.PubKey + + mu sync.RWMutex + + connections map[int64]*conn +} + +func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(privKey) + if err != nil { + return nil, err + } + + memhub.addPubKey(id, privKey.GetPublic()) + return &transport{ + psk: psk, + rcmgr: rcmgr, + localPeerID: id, + localPrivKey: privKey, + localPubKey: privKey.GetPublic(), + connections: make(map[int64]*conn), + }, nil +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(_ context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(rpid); err != nil { + return nil, err + } + + rl, ok := memhub.getListener(raddr.String()) + if !ok { + return nil, errors.New("failed to get listener") + } + + remotePubKey, ok := memhub.getPubKey(rpid) + if !ok { + return nil, errors.New("failed to get remote public key") + } + + lc, rc := t.newConnPair(remotePubKey, rpid, raddr) + + rl.connCh <- rc + return lc, nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return dialMatcher.Matches(addr) +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + memhub.addListener(laddr.String(), l) + + return l, nil +} + +func (t *transport) Proxy() bool { + return false +} + +// Protocols returns the set of protocols handled by this transport. +func (t *transport) Protocols() []int { + return []int{ma.P_MEMORY} +} + +func (t *transport) String() string { + return "MemoryTransport" +} + +func (t *transport) Close() error { + // TODO: Go trough all listeners and close them + t.mu.Lock() + defer t.mu.Unlock() + + for _, c := range t.connections { + c.Close() + //delete(t.connections, c.id) + } + + return nil +} + +func (t *transport) addConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[c.id] = c +} + +func (t *transport) removeConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.connections, c.id) +} + +func (t *transport) newConnPair(remotePubKey ic.PubKey, rpid peer.ID, raddr ma.Multiaddr) (*conn, *conn) { + sl, sr := newStreamPair() + + lc := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) + rc := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + + lc.rconn = rc + rc.rconn = lc + return lc, rc +} diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..f17835f3ff --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,72 @@ +package memory + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "testing" + + ic "github.com/libp2p/go-libp2p/core/crypto" + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} + +func TestMemoryProtocol(t *testing.T) { + t.Parallel() + tr := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + t.Parallel() + tr := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +} diff --git a/test-plans/go.mod b/test-plans/go.mod index b2eee27810..e61b3e1889 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -7,7 +7,7 @@ toolchain go1.22.1 require ( github.com/go-redis/redis/v8 v8.11.5 github.com/libp2p/go-libp2p v0.0.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 ) require ( diff --git a/test-plans/go.sum b/test-plans/go.sum index cbb839c369..d37b876a6f 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -185,6 +185,7 @@ github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E=