Skip to content

Commit

Permalink
Tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Jan 12, 2024
1 parent f7a36fb commit fdb6d17
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 49 deletions.
44 changes: 19 additions & 25 deletions httpx/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ const (
pingPeriod = 30 * time.Second
)

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}

type WebSocket interface {
Start()
Send(msg []byte)
Close()
Send([]byte)
Close(int)

OnMessage(fn func([]byte))
OnClose(fn func(int))
Expand All @@ -42,11 +39,12 @@ type socket struct {
readError chan error
writeError chan error
stopWriter chan bool
stopMonitor chan bool
closingWithCode int
rwWaitGroup sync.WaitGroup
monitorWaitGroup sync.WaitGroup
}

// NewWebSocket creates a new web socket from a regular HTTP request
func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
Expand All @@ -56,14 +54,13 @@ func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, se
conn.SetReadLimit(maxReadBytes)

return &socket{
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},
outbox: make(chan []byte, sendBuffer),
readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
stopMonitor: make(chan bool, 1),
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},

Check warning on line 59 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L59

Added line #L59 was not covered by tests
outbox: make(chan []byte, sendBuffer),
readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
}, nil
}

Expand All @@ -83,10 +80,11 @@ func (s *socket) Send(msg []byte) {
s.outbox <- msg
}

func (s *socket) Close() {
func (s *socket) Close(code int) {
s.closingWithCode = code
s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, ""))
s.conn.Close() // causes reader to stop
s.stopWriter <- true
s.stopMonitor <- true

s.monitorWaitGroup.Wait()
}
Expand All @@ -101,31 +99,27 @@ func (s *socket) monitor() {
s.monitorWaitGroup.Add(1)
defer s.monitorWaitGroup.Done()

closeCode := websocket.CloseNormalClosure

out:
for {
select {
case err := <-s.readError:
if e, ok := err.(*websocket.CloseError); ok {
closeCode = e.Code
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 {
s.closingWithCode = e.Code
}
s.stopWriter <- true // ensure writer is stopped
break out
case err := <-s.writeError:
if e, ok := err.(*websocket.CloseError); ok {
closeCode = e.Code
s.closingWithCode = e.Code
}
s.conn.Close() // ensure reader is stopped
break out

Check warning on line 116 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L111-L116

Added lines #L111 - L116 were not covered by tests
case <-s.stopMonitor:
break out
}
}

s.rwWaitGroup.Wait()

s.onClose(closeCode)
s.onClose(s.closingWithCode)
}

func (s *socket) reader() {
Expand Down
101 changes: 77 additions & 24 deletions httpx/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,46 @@ import (
"github.com/stretchr/testify/require"
)

func TestSocket(t *testing.T) {
var sock httpx.WebSocket
var err error
func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sock, err := httpx.NewWebSocket(w, r, 4096, 5)
require.NoError(t, err)

fn(sock)
}))

return "ws:" + strings.TrimPrefix(s.URL, "http:")
}

func newSocketConnection(t *testing.T, url string) *websocket.Conn {
d := websocket.Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
c, _, err := d.Dial(url, nil)
assert.NoError(t, err)
return c
}

func TestSocketMessages(t *testing.T) {
var sock httpx.WebSocket
var serverReceived [][]byte
var serverCloseCode int

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sock, err = httpx.NewWebSocket(w, r, 4096, 5)

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnMessage(func(b []byte) {
serverReceived = append(serverReceived, b)
})
sock.OnClose(func(code int) {
serverCloseCode = code
})

sock.Start()
})

require.NoError(t, err)
}))

wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:")

d := websocket.Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
conn, _, err := d.Dial(wsURL, nil)
assert.NoError(t, err)
assert.NotNil(t, conn)
conn := newSocketConnection(t, serverURL)

// send a message from the server...
sock.Send([]byte("from server"))
Expand All @@ -60,10 +67,56 @@ func TestSocket(t *testing.T) {
conn.WriteMessage(websocket.TextMessage, []byte("to server"))

// and check server received it
time.Sleep(time.Second)
time.Sleep(500 * time.Millisecond)
assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived)

sock.Close()
var connCloseCode int
conn.SetCloseHandler(func(code int, text string) error {
connCloseCode = code
return nil
})

sock.Close(1001)

conn.ReadMessage() // read the close message

assert.Equal(t, 1001, serverCloseCode)
assert.Equal(t, 1001, connCloseCode)
}

func TestSocketClientCloseWithMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, ""))
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode)
}

func TestSocketClientCloseWithoutMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, 1000, serverCloseCode)
assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode)
}

0 comments on commit fdb6d17

Please sign in to comment.