Skip to content

Commit

Permalink
feat: add websocket watcher (#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Aug 24, 2020
1 parent ddfbeed commit a9790e7
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 2 deletions.
2 changes: 2 additions & 0 deletions watcherx/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func Watch(ctx context.Context, u *url.URL, c EventChannel) error {
switch u.Scheme {
case "file":
return WatchFile(ctx, u.Path, c)
case "ws":
return WatchWebsocket(ctx, u, c)
}
return &errSchemeUnknown{u.Scheme}
}
85 changes: 83 additions & 2 deletions watcherx/event.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package watcherx

import "io"
import (
"bytes"
"encoding/json"
"io"
"io/ioutil"

"github.com/pkg/errors"
)

type (
Event interface {
Reader() io.Reader
Source() string
setSource(string)
}
source string
ErrorEvent struct {
Expand All @@ -19,20 +27,93 @@ type (
RemoveEvent struct {
source
}
serialEventType string
serialEvent struct {
Type serialEventType `json:"type"`
Data []byte `json:"data"`
Source source `json:"source"`
}
)

const (
serialTypeChange serialEventType = "change"
serialTypeRemove serialEventType = "remove"
serialTypeError serialEventType = "error"
)

var errUnknownEvent = errors.New("unknown event type")

func (e *ErrorEvent) Reader() io.Reader {
return nil
return bytes.NewBufferString(e.Error())
}

func (e *ErrorEvent) MarshalJSON() ([]byte, error) {
return json.Marshal(serialEvent{
Type: serialTypeError,
Data: []byte(e.Error()),
Source: e.source,
})
}

func (e source) Source() string {
return string(e)
}

func (e *source) setSource(nsrc string) {
*e = source(nsrc)
}

func (e *ChangeEvent) Reader() io.Reader {
return e.data
}

func (e *ChangeEvent) MarshalJSON() ([]byte, error) {
var data []byte
var err error
if e.data != nil {
data, err = ioutil.ReadAll(e.data)
}
if err != nil {
return nil, errors.WithStack(err)
}
return json.Marshal(serialEvent{
Type: serialTypeChange,
Data: data,
Source: e.source,
})
}

func (e *RemoveEvent) Reader() io.Reader {
return nil
}

func (e *RemoveEvent) MarshalJSON() ([]byte, error) {
return json.Marshal(serialEvent{
Type: serialTypeRemove,
Source: e.source,
})
}

func unmarshalEvent(data []byte) (Event, error) {
var serialEvent serialEvent
if err := json.Unmarshal(data, &serialEvent); err != nil {
return nil, errors.WithStack(err)
}
switch serialEvent.Type {
case serialTypeRemove:
return &RemoveEvent{
source: serialEvent.Source,
}, nil
case serialTypeChange:
return &ChangeEvent{
data: bytes.NewBuffer(serialEvent.Data),
source: serialEvent.Source,
}, nil
case serialTypeError:
return &ErrorEvent{
error: errors.New(string(serialEvent.Data)),
source: serialEvent.Source,
}, nil
}
return nil, errUnknownEvent
}
80 changes: 80 additions & 0 deletions watcherx/websocket_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package watcherx

import (
"context"
"net"
"net/url"
"strings"

"github.com/gorilla/websocket"
"github.com/pkg/errors"
)

func WatchWebsocket(ctx context.Context, u *url.URL, c EventChannel) error {
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
return errors.WithStack(err)
}

wsClosed := make(chan struct{})
go cleanupOnDone(ctx, conn, c, wsClosed)

go forwardWebsocketEvents(conn, c, u, wsClosed)

return nil
}

func cleanupOnDone(ctx context.Context, conn *websocket.Conn, c EventChannel, wsClosed <-chan struct{}) {
// wait for one of the events to occur
select {
case <-ctx.Done():
case <-wsClosed:
}

// clean up channel
close(c)
// attempt to close the websocket
// ignore errors as we are closing everything anyway
_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "context canceled by server"))
_ = conn.Close()
}

func forwardWebsocketEvents(ws *websocket.Conn, c EventChannel, u *url.URL, wsClosed chan<- struct{}) {
serverURL := source(u.String())

defer func() {
// this triggers the cleanupOnDone subroutine
close(wsClosed)
}()

for {
// receive messages, this call is blocking
_, msg, err := ws.ReadMessage()
if err != nil {
if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseNormalClosure {
return
}
// assuming the connection got closed through context canceling
if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" && strings.Contains(opErr.Err.Error(), "closed") {
return
}
c <- &ErrorEvent{
error: errors.WithStack(err),
source: serverURL,
}
return
}
e, err := unmarshalEvent(msg)
if err != nil {
c <- &ErrorEvent{
error: err,
source: serverURL,
}
continue
}
localURL := *u
localURL.Path = e.Source()
e.setSource(localURL.String())
c <- e
}
}
121 changes: 121 additions & 0 deletions watcherx/websocket_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package watcherx

import (
"context"
"net"
"net/http"
"net/url"
"strings"
"sync"

"github.com/gorilla/websocket"

"github.com/ory/herodot"
)

type eventChannelSlice struct {
sync.Mutex
cs []EventChannel
}

var wsClientChannels = eventChannelSlice{}

func WatchAndServeWS(ctx context.Context, u *url.URL, writer herodot.Writer) (http.HandlerFunc, error) {
c := make(EventChannel)
if err := Watch(ctx, u, c); err != nil {
return nil, err
}
go broadcaster(ctx, c)
return serveWS(ctx, writer), nil
}

func broadcaster(ctx context.Context, c EventChannel) {
defer func() {
close(c)
}()

for {
select {
case <-ctx.Done():
return
case e := <-c:
wsClientChannels.Lock()
for _, cc := range wsClientChannels.cs {
cc <- e
}
wsClientChannels.Unlock()
}
}
}

func notifyOnClose(ws *websocket.Conn, c chan<- struct{}) {
for {
// blocking call to ReadMessage that waits for a close message
_, _, err := ws.ReadMessage()
if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseNormalClosure {
close(c)
return
}
if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" && strings.Contains(opErr.Err.Error(), "closed") {
// the context got canceled and therefore the connection closed
close(c)
return
}
}
}

func serveWS(ctx context.Context, writer herodot.Writer) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ws, err := (&websocket.Upgrader{
ReadBufferSize: 256, // the only message we expect is the close message
WriteBufferSize: 1024,
}).Upgrade(w, r, nil)
if err != nil {
writer.WriteError(w, r, err)
return
}

// make channel and register it at broadcaster
c := make(EventChannel)
wsClientChannels.Lock()
wsClientChannels.cs = append(wsClientChannels.cs, c)
wsClientChannels.Unlock()

wsClosed := make(chan struct{})
go notifyOnClose(ws, wsClosed)

defer func() {
// attempt to close the websocket
// ignore errors as we are closing everything anyway
_ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "server context canceled"))
_ = ws.Close()

wsClientChannels.Lock()
for i, cc := range wsClientChannels.cs {
if c == cc {
wsClientChannels.cs[i] = wsClientChannels.cs[len(wsClientChannels.cs)-1]
wsClientChannels.cs[len(wsClientChannels.cs)-1] = nil
wsClientChannels.cs = wsClientChannels.cs[:len(wsClientChannels.cs)-1]
}
}
wsClientChannels.Unlock()
close(c)
}()

for {
select {
case <-ctx.Done():
return
case <-wsClosed:
return
case e, ok := <-c:
if !ok {
return
}
if err := ws.WriteJSON(e); err != nil {
return
}
}
}
}
}
Loading

0 comments on commit a9790e7

Please sign in to comment.