Skip to content

Commit

Permalink
add support for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Sep 13, 2024
1 parent 235a2da commit 5b7016e
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 17 deletions.
2 changes: 1 addition & 1 deletion core/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ConnErrorCode uint32

type ConnError struct {
Remote bool
ErrorCode uint32
ErrorCode ConnErrorCode
}

func (c *ConnError) Error() string {
Expand Down
4 changes: 2 additions & 2 deletions core/network/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ type MuxedConn interface {
AcceptStream() (MuxedStream, error)
}

type ConnWithErrorer interface {
type CloseWithErrorer interface {
// CloseWithError closes the connection with errCode. The errCode is sent
// to the peer.
ConnWithError(errCode ConnErrorCode) error
CloseWithError(errCode ConnErrorCode) error
}

// Multiplexer wraps a net.Conn with a stream multiplexing
Expand Down
4 changes: 2 additions & 2 deletions p2p/muxer/yamux/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (c *conn) IsClosed() bool {
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
s, err := c.yamux().OpenStream(ctx)
if err != nil {
return nil, err
return nil, parseResetError(err)
}

return (*stream)(s), nil
Expand All @@ -45,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
s, err := c.yamux().AcceptStream()
return (*stream)(s), err
return (*stream)(s), parseResetError(err)
}

func (c *conn) yamux() *yamux.Session {
Expand Down
12 changes: 7 additions & 5 deletions p2p/muxer/yamux/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ func parseResetError(err error) error {
if err == nil {
return err
}
if errors.Is(err, yamux.ErrStreamReset) {
se := &yamux.StreamError{}
if errors.As(err, &se) {
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)}
}
se := &yamux.StreamError{}
if errors.As(err, &se) {
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)}
}
ce := &yamux.GoAwayError{}
if errors.As(err, &ce) {
return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode)}
}
return err
}
Expand Down
8 changes: 8 additions & 0 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,14 @@ func (c connWithMetrics) Close() error {
return c.CapableConn.Close()
}

func (c connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error {
c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok {
return ce.CloseWithError(errCode)
}
return c.CapableConn.Close()
}

func (c connWithMetrics) Stat() network.ConnStats {
if cs, ok := c.CapableConn.(network.ConnStat); ok {
return cs.Stat()
Expand Down
4 changes: 1 addition & 3 deletions p2p/net/swarm/swarm_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ func (c *Conn) doClose(errCode network.ConnErrorCode) {
c.streams.Unlock()

if errCode != 0 {
if ce, ok := c.conn.(interface {
CloseWithError(network.ConnErrorCode) error
}); ok {
if ce, ok := c.conn.(network.CloseWithErrorer); ok {
c.err = ce.CloseWithError(errCode)
} else {
c.err = c.conn.Close()
Expand Down
7 changes: 7 additions & 0 deletions p2p/net/upgrader/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ func (t *transportConn) ConnState() network.ConnectionState {
UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation,
}
}

func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error {
if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok {
return ce.CloseWithError(errCode)
}
return t.Close()
}
162 changes: 160 additions & 2 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,8 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
func TestStreamErrorCode(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if tc.Name != "QUIC" && tc.Name != "TCP / TLS / Yamux" && tc.Name != "WebRTC" {
t.Skipf("skipping: %s, only implemented for QUIC", tc.Name)
if tc.Name == "WebTransport" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
server := tc.HostGenerator(t, TransportTestCaseOpts{})
Expand Down Expand Up @@ -841,6 +841,9 @@ func TestStreamErrorCode(t *testing.T) {
}
_, err = s.Read(b)
errCh <- err

_, err = s.Write(b)
errCh <- err
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand All @@ -865,8 +868,163 @@ func TestStreamErrorCode(t *testing.T) {
_, err = s.Write(buf)
checkError(err, 42, false)

err = <-errCh // read error
checkError(err, 42, true)

err = <-errCh // write error
checkError(err, 42, true)
})
}
}

// TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors
func TestStreamErrorCodeConnClosed(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()

checkError := func(err error, code network.ConnErrorCode, remote bool) {
t.Helper()
if err == nil {
t.Fatal("expected non nil error")
}
ce := &network.ConnError{}
if errors.As(err, &ce) {
require.Equal(t, code, ce.ErrorCode)
require.Equal(t, remote, ce.Remote)
return
}
t.Fatal("expected network.ConnError, got:", err)
}

errCh := make(chan error)
server.SetStreamHandler("/test", func(s network.Stream) {
defer s.Reset()
b := make([]byte, 10)
n, err := s.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = s.Write(b[:n])
if !assert.NoError(t, err) {
return
}
_, err = s.Read(b)
errCh <- err

_, err = s.Write(b)
errCh <- err
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)

_, err = s.Write([]byte("hello"))
require.NoError(t, err)

buf := make([]byte, 10)
n, err := s.Read(buf)
require.NoError(t, err)
require.Equal(t, []byte("hello"), buf[:n])

err = s.Conn().CloseWithError(42)
require.NoError(t, err)

_, err = s.Read(buf)
checkError(err, 42, false)

_, err = s.Write(buf)
checkError(err, 42, false)

err = <-errCh
checkError(err, 42, true)

err = <-errCh
checkError(err, 42, true)
})
}
}

// TestConnectionErrorCode tests correctness for resetting stream with errors
func TestConnectionErrorCode(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
t.Skipf("skipping: %s, not implemented", tc.Name)
return
}
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()

checkError := func(err error, code network.ConnErrorCode, remote bool) {
t.Helper()
if err == nil {
t.Fatal("expected non nil error")
}
ce := &network.ConnError{}
if errors.As(err, &ce) {
require.Equal(t, code, ce.ErrorCode)
require.Equal(t, remote, ce.Remote)
return
}
t.Fatal("expected network.ConnError, got:", err)
}

errCh := make(chan error)
server.SetStreamHandler("/test", func(s network.Stream) {
defer s.Reset()
b := make([]byte, 10)
n, err := s.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = s.Write(b[:n])
if !assert.NoError(t, err) {
return
}

_, err = s.Read(b)
if !assert.Error(t, err) {
return
}
_, err = s.Conn().NewStream(context.Background())
errCh <- err
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)

_, err = s.Write([]byte("hello"))
require.NoError(t, err)

buf := make([]byte, 10)
n, err := s.Read(buf)
require.NoError(t, err)
require.Equal(t, []byte("hello"), buf[:n])

err = s.Conn().CloseWithError(42)
require.NoError(t, err)

str, err := s.Conn().NewStream(context.Background())
require.Nil(t, str)
checkError(err, 42, false)

err = <-errCh
checkError(err, 42, true)

})
}
}
4 changes: 2 additions & 2 deletions p2p/transport/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
qstr, err := c.quicConn.OpenStreamSync(ctx)
if err != nil {
return nil, err
return nil, parseStreamError(err)
}
return &stream{Stream: qstr}, nil
}
Expand All @@ -70,7 +70,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
func (c *conn) AcceptStream() (network.MuxedStream, error) {
qstr, err := c.quicConn.AcceptStream(context.Background())
if err != nil {
return nil, err
return nil, parseStreamError(err)
}
return &stream{Stream: qstr}, nil
}
Expand Down
9 changes: 9 additions & 0 deletions p2p/transport/quic/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,22 @@ func parseStreamError(err error) error {
if errors.As(err, &se) {
code := se.ErrorCode
if code > math.MaxUint32 {
// TODO(sukunrt): do we need this?
code = reset
}
err = &network.StreamError{
ErrorCode: network.StreamErrorCode(code),
Remote: se.Remote,
}
}
ae := &quic.ApplicationError{}
if errors.As(err, &ae) {
code := ae.ErrorCode
err = &network.ConnError{
ErrorCode: network.ConnErrorCode(code),
Remote: ae.Remote,
}
}
return err
}

Expand Down
8 changes: 8 additions & 0 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,11 @@ func (c *capableConn) ConnState() network.ConnectionState {
cs.Transport = "websocket"
return cs
}

// CloseWithError implements network.CloseWithErrorer
func (c *capableConn) CloseWithError(errCode network.ConnErrorCode) error {
if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok {
return ce.CloseWithError(errCode)
}
return c.Close()
}

0 comments on commit 5b7016e

Please sign in to comment.