From f7a36fbd9ee970bf2cf1295069edfbcffe314648 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Thu, 11 Jan 2024 16:06:03 -0500 Subject: [PATCH] Add websocket struct to httpx --- go.mod | 1 + go.sum | 2 + httpx/websocket.go | 174 ++++++++++++++++++++++++++++++++++++++++ httpx/websocket_test.go | 69 ++++++++++++++++ 4 files changed, 246 insertions(+) create mode 100644 httpx/websocket.go create mode 100644 httpx/websocket_test.go diff --git a/go.mod b/go.mod index 741e606..2ab20a8 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/gorilla/websocket v1.5.1 github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index d05d53d..556029a 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= diff --git a/httpx/websocket.go b/httpx/websocket.go new file mode 100644 index 0000000..147a75a --- /dev/null +++ b/httpx/websocket.go @@ -0,0 +1,174 @@ +package httpx + +import ( + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + // max time for between reading a message before socket is considered closed + maxReadWait = 60 * time.Second + + // maximum time to wait for message to be written + maxWriteWait = 15 * time.Second + + // how often to send a ping message + pingPeriod = 30 * time.Second +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +type WebSocket interface { + Start() + Send(msg []byte) + Close() + + OnMessage(fn func([]byte)) + OnClose(fn func(int)) +} + +// Socket implemention using gorilla library +type socket struct { + conn *websocket.Conn + onMessage func([]byte) + onClose func(int) + outbox chan []byte + readError chan error + writeError chan error + stopWriter chan bool + stopMonitor chan bool + rwWaitGroup sync.WaitGroup + monitorWaitGroup sync.WaitGroup +} + +func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + + conn.SetReadLimit(maxReadBytes) + + return &socket{ + conn: conn, + onMessage: func([]byte) {}, + onClose: func(int) {}, + outbox: make(chan []byte, sendBuffer), + readError: make(chan error, 1), + writeError: make(chan error, 1), + stopWriter: make(chan bool, 1), + stopMonitor: make(chan bool, 1), + }, nil +} + +func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn } +func (s *socket) OnClose(fn func(int)) { s.onClose = fn } + +func (s *socket) Start() { + s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) + s.conn.SetPongHandler(s.pong) + + go s.monitor() + go s.reader() + go s.writer() +} + +func (s *socket) Send(msg []byte) { + s.outbox <- msg +} + +func (s *socket) Close() { + s.conn.Close() // causes reader to stop + s.stopWriter <- true + s.stopMonitor <- true + + s.monitorWaitGroup.Wait() +} + +func (s *socket) pong(m string) error { + s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) + + return nil +} + +func (s *socket) monitor() { + s.monitorWaitGroup.Add(1) + defer s.monitorWaitGroup.Done() + + closeCode := websocket.CloseNormalClosure + +out: + for { + select { + case err := <-s.readError: + if e, ok := err.(*websocket.CloseError); ok { + closeCode = e.Code + } + s.stopWriter <- true // ensure writer is stopped + break out + case err := <-s.writeError: + if e, ok := err.(*websocket.CloseError); ok { + closeCode = e.Code + } + s.conn.Close() // ensure reader is stopped + break out + case <-s.stopMonitor: + break out + } + } + + s.rwWaitGroup.Wait() + + s.onClose(closeCode) +} + +func (s *socket) reader() { + s.rwWaitGroup.Add(1) + defer s.rwWaitGroup.Done() + + for { + _, message, err := s.conn.ReadMessage() + if err != nil { + s.readError <- err + return + } + + s.onMessage(message) + } +} + +func (s *socket) writer() { + s.rwWaitGroup.Add(1) + defer s.rwWaitGroup.Done() + + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case msg := <-s.outbox: + s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) + + err := s.conn.WriteMessage(websocket.TextMessage, msg) + if err != nil { + s.writeError <- err + return + } + case <-ticker.C: + s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) + + if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + s.writeError <- err + return + } + case <-s.stopWriter: + return + } + } +} diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go new file mode 100644 index 0000000..714ccda --- /dev/null +++ b/httpx/websocket_test.go @@ -0,0 +1,69 @@ +package httpx_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/nyaruka/gocommon/httpx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSocket(t *testing.T) { + var sock httpx.WebSocket + var err error + + var serverReceived [][]byte + var serverCloseCode int + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sock, err = httpx.NewWebSocket(w, r, 4096, 5) + + sock.OnMessage(func(b []byte) { + serverReceived = append(serverReceived, b) + }) + sock.OnClose(func(code int) { + serverCloseCode = code + }) + + sock.Start() + + require.NoError(t, err) + })) + + wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:") + + d := websocket.Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, + } + conn, _, err := d.Dial(wsURL, nil) + assert.NoError(t, err) + assert.NotNil(t, conn) + + // send a message from the server... + sock.Send([]byte("from server")) + + // and read it from the client + msgType, msg, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, 1, msgType) + assert.Equal(t, "from server", string(msg)) + + // send a message from the client... + conn.WriteMessage(websocket.TextMessage, []byte("to server")) + + // and check server received it + time.Sleep(time.Second) + assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) + + sock.Close() + + assert.Equal(t, 1000, serverCloseCode) +}