Skip to content

Commit

Permalink
Add websocket struct to httpx
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Jan 11, 2024
1 parent ac4221c commit f7a36fb
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/gorilla/websocket v1.5.1
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
Expand Down
174 changes: 174 additions & 0 deletions httpx/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package httpx

import (
"net/http"
"sync"
"time"

"github.com/gorilla/websocket"
)

const (
// max time for between reading a message before socket is considered closed
maxReadWait = 60 * time.Second

// maximum time to wait for message to be written
maxWriteWait = 15 * time.Second

// how often to send a ping message
pingPeriod = 30 * time.Second
)

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

type WebSocket interface {
Start()
Send(msg []byte)
Close()

OnMessage(fn func([]byte))
OnClose(fn func(int))
}

// Socket implemention using gorilla library
type socket struct {
conn *websocket.Conn
onMessage func([]byte)
onClose func(int)
outbox chan []byte
readError chan error
writeError chan error
stopWriter chan bool
stopMonitor chan bool
rwWaitGroup sync.WaitGroup
monitorWaitGroup sync.WaitGroup
}

func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}

Check warning on line 54 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L53-L54

Added lines #L53 - L54 were not covered by tests

conn.SetReadLimit(maxReadBytes)

return &socket{
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},

Check warning on line 61 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L61

Added line #L61 was not covered by tests
outbox: make(chan []byte, sendBuffer),
readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
stopMonitor: make(chan bool, 1),
}, nil
}

func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn }
func (s *socket) OnClose(fn func(int)) { s.onClose = fn }

func (s *socket) Start() {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))
s.conn.SetPongHandler(s.pong)

go s.monitor()
go s.reader()
go s.writer()
}

func (s *socket) Send(msg []byte) {
s.outbox <- msg
}

func (s *socket) Close() {
s.conn.Close() // causes reader to stop
s.stopWriter <- true
s.stopMonitor <- true

s.monitorWaitGroup.Wait()
}

func (s *socket) pong(m string) error {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))

return nil

Check warning on line 97 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L94-L97

Added lines #L94 - L97 were not covered by tests
}

func (s *socket) monitor() {
s.monitorWaitGroup.Add(1)
defer s.monitorWaitGroup.Done()

closeCode := websocket.CloseNormalClosure

out:
for {
select {
case err := <-s.readError:
if e, ok := err.(*websocket.CloseError); ok {
closeCode = e.Code
}

Check warning on line 112 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L111-L112

Added lines #L111 - L112 were not covered by tests
s.stopWriter <- true // ensure writer is stopped
break out
case err := <-s.writeError:
if e, ok := err.(*websocket.CloseError); ok {
closeCode = e.Code
}
s.conn.Close() // ensure reader is stopped
break out
case <-s.stopMonitor:
break out

Check warning on line 122 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L115-L122

Added lines #L115 - L122 were not covered by tests
}
}

s.rwWaitGroup.Wait()

s.onClose(closeCode)
}

func (s *socket) reader() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

for {
_, message, err := s.conn.ReadMessage()
if err != nil {
s.readError <- err
return
}

s.onMessage(message)
}
}

func (s *socket) writer() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()

for {
select {
case msg := <-s.outbox:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

err := s.conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
s.writeError <- err
return
}
case <-ticker.C:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
s.writeError <- err
return
}

Check warning on line 169 in httpx/websocket.go

View check run for this annotation

Codecov / codecov/patch

httpx/websocket.go#L160-L169

Added lines #L160 - L169 were not covered by tests
case <-s.stopWriter:
return
}
}
}
69 changes: 69 additions & 0 deletions httpx/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package httpx_test

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/nyaruka/gocommon/httpx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSocket(t *testing.T) {
var sock httpx.WebSocket
var err error

var serverReceived [][]byte
var serverCloseCode int

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sock, err = httpx.NewWebSocket(w, r, 4096, 5)

sock.OnMessage(func(b []byte) {
serverReceived = append(serverReceived, b)
})
sock.OnClose(func(code int) {
serverCloseCode = code
})

sock.Start()

require.NoError(t, err)
}))

wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:")

d := websocket.Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
conn, _, err := d.Dial(wsURL, nil)
assert.NoError(t, err)
assert.NotNil(t, conn)

// send a message from the server...
sock.Send([]byte("from server"))

// and read it from the client
msgType, msg, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, 1, msgType)
assert.Equal(t, "from server", string(msg))

// send a message from the client...
conn.WriteMessage(websocket.TextMessage, []byte("to server"))

// and check server received it
time.Sleep(time.Second)
assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived)

sock.Close()

assert.Equal(t, 1000, serverCloseCode)
}

0 comments on commit f7a36fb

Please sign in to comment.