diff --git a/httpx/websocket.go b/httpx/websocket.go index 147a75a..e0ece26 100644 --- a/httpx/websocket.go +++ b/httpx/websocket.go @@ -19,15 +19,12 @@ const ( pingPeriod = 30 * time.Second ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} +var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} type WebSocket interface { Start() - Send(msg []byte) - Close() + Send([]byte) + Close(int) OnMessage(fn func([]byte)) OnClose(fn func(int)) @@ -42,11 +39,12 @@ type socket struct { readError chan error writeError chan error stopWriter chan bool - stopMonitor chan bool + closingWithCode int rwWaitGroup sync.WaitGroup monitorWaitGroup sync.WaitGroup } +// NewWebSocket creates a new web socket from a regular HTTP request func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -56,14 +54,13 @@ func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, se conn.SetReadLimit(maxReadBytes) return &socket{ - conn: conn, - onMessage: func([]byte) {}, - onClose: func(int) {}, - outbox: make(chan []byte, sendBuffer), - readError: make(chan error, 1), - writeError: make(chan error, 1), - stopWriter: make(chan bool, 1), - stopMonitor: make(chan bool, 1), + conn: conn, + onMessage: func([]byte) {}, + onClose: func(int) {}, + outbox: make(chan []byte, sendBuffer), + readError: make(chan error, 1), + writeError: make(chan error, 1), + stopWriter: make(chan bool, 1), }, nil } @@ -83,10 +80,11 @@ func (s *socket) Send(msg []byte) { s.outbox <- msg } -func (s *socket) Close() { +func (s *socket) Close(code int) { + s.closingWithCode = code + s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, "")) s.conn.Close() // causes reader to stop s.stopWriter <- true - s.stopMonitor <- true s.monitorWaitGroup.Wait() } @@ -101,31 +99,27 @@ func (s *socket) monitor() { s.monitorWaitGroup.Add(1) defer s.monitorWaitGroup.Done() - closeCode := websocket.CloseNormalClosure - out: for { select { case err := <-s.readError: - if e, ok := err.(*websocket.CloseError); ok { - closeCode = e.Code + if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 { + s.closingWithCode = e.Code } s.stopWriter <- true // ensure writer is stopped break out case err := <-s.writeError: if e, ok := err.(*websocket.CloseError); ok { - closeCode = e.Code + s.closingWithCode = e.Code } s.conn.Close() // ensure reader is stopped break out - case <-s.stopMonitor: - break out } } s.rwWaitGroup.Wait() - s.onClose(closeCode) + s.onClose(s.closingWithCode) } func (s *socket) reader() { diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go index 714ccda..a78c119 100644 --- a/httpx/websocket_test.go +++ b/httpx/websocket_test.go @@ -13,39 +13,46 @@ import ( "github.com/stretchr/testify/require" ) -func TestSocket(t *testing.T) { - var sock httpx.WebSocket - var err error +func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sock, err := httpx.NewWebSocket(w, r, 4096, 5) + require.NoError(t, err) + + fn(sock) + })) + return "ws:" + strings.TrimPrefix(s.URL, "http:") +} + +func newSocketConnection(t *testing.T, url string) *websocket.Conn { + d := websocket.Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, + } + c, _, err := d.Dial(url, nil) + assert.NoError(t, err) + return c +} + +func TestSocketMessages(t *testing.T) { + var sock httpx.WebSocket var serverReceived [][]byte var serverCloseCode int - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sock, err = httpx.NewWebSocket(w, r, 4096, 5) - + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws sock.OnMessage(func(b []byte) { serverReceived = append(serverReceived, b) }) sock.OnClose(func(code int) { serverCloseCode = code }) - sock.Start() + }) - require.NoError(t, err) - })) - - wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:") - - d := websocket.Dialer{ - Subprotocols: []string{"p1", "p2"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - HandshakeTimeout: 30 * time.Second, - } - conn, _, err := d.Dial(wsURL, nil) - assert.NoError(t, err) - assert.NotNil(t, conn) + conn := newSocketConnection(t, serverURL) // send a message from the server... sock.Send([]byte("from server")) @@ -60,10 +67,56 @@ func TestSocket(t *testing.T) { conn.WriteMessage(websocket.TextMessage, []byte("to server")) // and check server received it - time.Sleep(time.Second) + time.Sleep(500 * time.Millisecond) assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) - sock.Close() + var connCloseCode int + conn.SetCloseHandler(func(code int, text string) error { + connCloseCode = code + return nil + }) + + sock.Close(1001) + + conn.ReadMessage() // read the close message + + assert.Equal(t, 1001, serverCloseCode) + assert.Equal(t, 1001, connCloseCode) +} + +func TestSocketClientCloseWithMessage(t *testing.T) { + var sock httpx.WebSocket + var serverCloseCode int + + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws + sock.OnClose(func(code int) { serverCloseCode = code }) + sock.Start() + }) + + conn := newSocketConnection(t, serverURL) + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")) + conn.Close() + + time.Sleep(250 * time.Millisecond) + + assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode) +} + +func TestSocketClientCloseWithoutMessage(t *testing.T) { + var sock httpx.WebSocket + var serverCloseCode int + + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws + sock.OnClose(func(code int) { serverCloseCode = code }) + sock.Start() + }) + + conn := newSocketConnection(t, serverURL) + conn.Close() + + time.Sleep(250 * time.Millisecond) - assert.Equal(t, 1000, serverCloseCode) + assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode) }