diff --git a/watcherx/definitions.go b/watcherx/definitions.go index 8c29773d..f0f0d1db 100644 --- a/watcherx/definitions.go +++ b/watcherx/definitions.go @@ -29,6 +29,8 @@ func Watch(ctx context.Context, u *url.URL, c EventChannel) error { switch u.Scheme { case "file": return WatchFile(ctx, u.Path, c) + case "ws": + return WatchWebsocket(ctx, u, c) } return &errSchemeUnknown{u.Scheme} } diff --git a/watcherx/event.go b/watcherx/event.go index 52df6c43..4a308cc1 100644 --- a/watcherx/event.go +++ b/watcherx/event.go @@ -1,11 +1,19 @@ package watcherx -import "io" +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + + "github.com/pkg/errors" +) type ( Event interface { Reader() io.Reader Source() string + setSource(string) } source string ErrorEvent struct { @@ -19,20 +27,93 @@ type ( RemoveEvent struct { source } + serialEventType string + serialEvent struct { + Type serialEventType `json:"type"` + Data []byte `json:"data"` + Source source `json:"source"` + } ) +const ( + serialTypeChange serialEventType = "change" + serialTypeRemove serialEventType = "remove" + serialTypeError serialEventType = "error" +) + +var errUnknownEvent = errors.New("unknown event type") + func (e *ErrorEvent) Reader() io.Reader { - return nil + return bytes.NewBufferString(e.Error()) +} + +func (e *ErrorEvent) MarshalJSON() ([]byte, error) { + return json.Marshal(serialEvent{ + Type: serialTypeError, + Data: []byte(e.Error()), + Source: e.source, + }) } func (e source) Source() string { return string(e) } +func (e *source) setSource(nsrc string) { + *e = source(nsrc) +} + func (e *ChangeEvent) Reader() io.Reader { return e.data } +func (e *ChangeEvent) MarshalJSON() ([]byte, error) { + var data []byte + var err error + if e.data != nil { + data, err = ioutil.ReadAll(e.data) + } + if err != nil { + return nil, errors.WithStack(err) + } + return json.Marshal(serialEvent{ + Type: serialTypeChange, + Data: data, + Source: e.source, + }) +} + func (e *RemoveEvent) Reader() io.Reader { return nil } + +func (e *RemoveEvent) MarshalJSON() ([]byte, error) { + return json.Marshal(serialEvent{ + Type: serialTypeRemove, + Source: e.source, + }) +} + +func unmarshalEvent(data []byte) (Event, error) { + var serialEvent serialEvent + if err := json.Unmarshal(data, &serialEvent); err != nil { + return nil, errors.WithStack(err) + } + switch serialEvent.Type { + case serialTypeRemove: + return &RemoveEvent{ + source: serialEvent.Source, + }, nil + case serialTypeChange: + return &ChangeEvent{ + data: bytes.NewBuffer(serialEvent.Data), + source: serialEvent.Source, + }, nil + case serialTypeError: + return &ErrorEvent{ + error: errors.New(string(serialEvent.Data)), + source: serialEvent.Source, + }, nil + } + return nil, errUnknownEvent +} diff --git a/watcherx/websocket_client.go b/watcherx/websocket_client.go new file mode 100644 index 00000000..f1f2da3a --- /dev/null +++ b/watcherx/websocket_client.go @@ -0,0 +1,80 @@ +package watcherx + +import ( + "context" + "net" + "net/url" + "strings" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +func WatchWebsocket(ctx context.Context, u *url.URL, c EventChannel) error { + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return errors.WithStack(err) + } + + wsClosed := make(chan struct{}) + go cleanupOnDone(ctx, conn, c, wsClosed) + + go forwardWebsocketEvents(conn, c, u, wsClosed) + + return nil +} + +func cleanupOnDone(ctx context.Context, conn *websocket.Conn, c EventChannel, wsClosed <-chan struct{}) { + // wait for one of the events to occur + select { + case <-ctx.Done(): + case <-wsClosed: + } + + // clean up channel + close(c) + // attempt to close the websocket + // ignore errors as we are closing everything anyway + _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "context canceled by server")) + _ = conn.Close() +} + +func forwardWebsocketEvents(ws *websocket.Conn, c EventChannel, u *url.URL, wsClosed chan<- struct{}) { + serverURL := source(u.String()) + + defer func() { + // this triggers the cleanupOnDone subroutine + close(wsClosed) + }() + + for { + // receive messages, this call is blocking + _, msg, err := ws.ReadMessage() + if err != nil { + if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseNormalClosure { + return + } + // assuming the connection got closed through context canceling + if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" && strings.Contains(opErr.Err.Error(), "closed") { + return + } + c <- &ErrorEvent{ + error: errors.WithStack(err), + source: serverURL, + } + return + } + e, err := unmarshalEvent(msg) + if err != nil { + c <- &ErrorEvent{ + error: err, + source: serverURL, + } + continue + } + localURL := *u + localURL.Path = e.Source() + e.setSource(localURL.String()) + c <- e + } +} diff --git a/watcherx/websocket_server.go b/watcherx/websocket_server.go new file mode 100644 index 00000000..5673b3d7 --- /dev/null +++ b/watcherx/websocket_server.go @@ -0,0 +1,121 @@ +package watcherx + +import ( + "context" + "net" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/gorilla/websocket" + + "github.com/ory/herodot" +) + +type eventChannelSlice struct { + sync.Mutex + cs []EventChannel +} + +var wsClientChannels = eventChannelSlice{} + +func WatchAndServeWS(ctx context.Context, u *url.URL, writer herodot.Writer) (http.HandlerFunc, error) { + c := make(EventChannel) + if err := Watch(ctx, u, c); err != nil { + return nil, err + } + go broadcaster(ctx, c) + return serveWS(ctx, writer), nil +} + +func broadcaster(ctx context.Context, c EventChannel) { + defer func() { + close(c) + }() + + for { + select { + case <-ctx.Done(): + return + case e := <-c: + wsClientChannels.Lock() + for _, cc := range wsClientChannels.cs { + cc <- e + } + wsClientChannels.Unlock() + } + } +} + +func notifyOnClose(ws *websocket.Conn, c chan<- struct{}) { + for { + // blocking call to ReadMessage that waits for a close message + _, _, err := ws.ReadMessage() + if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseNormalClosure { + close(c) + return + } + if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" && strings.Contains(opErr.Err.Error(), "closed") { + // the context got canceled and therefore the connection closed + close(c) + return + } + } +} + +func serveWS(ctx context.Context, writer herodot.Writer) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + ws, err := (&websocket.Upgrader{ + ReadBufferSize: 256, // the only message we expect is the close message + WriteBufferSize: 1024, + }).Upgrade(w, r, nil) + if err != nil { + writer.WriteError(w, r, err) + return + } + + // make channel and register it at broadcaster + c := make(EventChannel) + wsClientChannels.Lock() + wsClientChannels.cs = append(wsClientChannels.cs, c) + wsClientChannels.Unlock() + + wsClosed := make(chan struct{}) + go notifyOnClose(ws, wsClosed) + + defer func() { + // attempt to close the websocket + // ignore errors as we are closing everything anyway + _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "server context canceled")) + _ = ws.Close() + + wsClientChannels.Lock() + for i, cc := range wsClientChannels.cs { + if c == cc { + wsClientChannels.cs[i] = wsClientChannels.cs[len(wsClientChannels.cs)-1] + wsClientChannels.cs[len(wsClientChannels.cs)-1] = nil + wsClientChannels.cs = wsClientChannels.cs[:len(wsClientChannels.cs)-1] + } + } + wsClientChannels.Unlock() + close(c) + }() + + for { + select { + case <-ctx.Done(): + return + case <-wsClosed: + return + case e, ok := <-c: + if !ok { + return + } + if err := ws.WriteJSON(e); err != nil { + return + } + } + } + } +} diff --git a/watcherx/websocket_test.go b/watcherx/websocket_test.go new file mode 100644 index 00000000..daef914f --- /dev/null +++ b/watcherx/websocket_test.go @@ -0,0 +1,159 @@ +package watcherx + +import ( + "context" + "fmt" + "net/http/httptest" + "os" + "path" + "strings" + "testing" + + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/herodot" + "github.com/ory/x/urlx" +) + +func TestWatchWebsocket(t *testing.T) { + t.Run("case=forwards events", func(t *testing.T) { + ctx, c, dir, cancel := setup(t) + defer cancel() + l, hook := test.NewNullLogger() + + fn := path.Join(dir, "some.file") + f, err := os.Create(fn) + require.NoError(t, err) + + handler, err := WatchAndServeWS(ctx, urlx.ParseOrPanic("file://"+fn), herodot.NewJSONWriter(l)) + require.NoError(t, err) + s := httptest.NewServer(handler) + + u := urlx.ParseOrPanic("ws" + strings.TrimLeft(s.URL, "http")) + require.NoError(t, WatchWebsocket(ctx, u, c)) + + _, err = fmt.Fprint(f, "content here") + require.NoError(t, err) + require.NoError(t, f.Close()) + assertChange(t, <-c, "content here", u.String()+fn) + + require.NoError(t, os.Remove(fn)) + assertRemove(t, <-c, u.String()+fn) + + assert.Len(t, hook.Entries, 0, "%+v", hook.Entries) + }) + + t.Run("case=client closes itself on context cancel", func(t *testing.T) { + ctx1, c, dir, cancel1 := setup(t) + defer cancel1() + l, hook := test.NewNullLogger() + + fn := path.Join(dir, "some.file") + + handler, err := WatchAndServeWS(ctx1, urlx.ParseOrPanic("file://"+fn), herodot.NewJSONWriter(l)) + require.NoError(t, err) + s := httptest.NewServer(handler) + + ctx2, cancel2 := context.WithCancel(context.Background()) + u := urlx.ParseOrPanic("ws" + strings.TrimLeft(s.URL, "http")) + require.NoError(t, WatchWebsocket(ctx2, u, c)) + + cancel2() + + e, ok := <-c + assert.False(t, ok, "%#v", e) + + assert.Len(t, hook.Entries, 0, "%+v", hook.Entries) + }) + + t.Run("case=quits client watcher when server connection is closed", func(t *testing.T) { + ctxClient, c, dir, cancel := setup(t) + defer cancel() + l, hook := test.NewNullLogger() + + fn := path.Join(dir, "some.file") + + ctxServe, cancelServe := context.WithCancel(context.Background()) + handler, err := WatchAndServeWS(ctxServe, urlx.ParseOrPanic("file://"+fn), herodot.NewJSONWriter(l)) + require.NoError(t, err) + s := httptest.NewServer(handler) + + u := urlx.ParseOrPanic("ws" + strings.TrimLeft(s.URL, "http")) + require.NoError(t, WatchWebsocket(ctxClient, u, c)) + + cancelServe() + + e, ok := <-c + assert.False(t, ok, "%#v", e) + + assert.Len(t, hook.Entries, 0, "%+v", hook.Entries) + }) + + t.Run("case=successive watching works after client connection is closed", func(t *testing.T) { + ctxServer, c, dir, cancel := setup(t) + defer cancel() + l, hook := test.NewNullLogger() + + fn := path.Join(dir, "some.file") + + handler, err := WatchAndServeWS(ctxServer, urlx.ParseOrPanic("file://"+fn), herodot.NewJSONWriter(l)) + require.NoError(t, err) + s := httptest.NewServer(handler) + + ctxClient1, cancelClient1 := context.WithCancel(context.Background()) + u := urlx.ParseOrPanic("ws" + strings.TrimLeft(s.URL, "http")) + require.NoError(t, WatchWebsocket(ctxClient1, u, c)) + + cancelClient1() + + _, ok := <-c + assert.False(t, ok) + + ctxClient2, cancelClient2 := context.WithCancel(context.Background()) + defer cancelClient2() + c2 := make(EventChannel) + require.NoError(t, WatchWebsocket(ctxClient2, u, c2)) + + f, err := os.Create(fn) + require.NoError(t, err) + require.NoError(t, f.Close()) + + assertChange(t, <-c2, "", u.String()+fn) + + assert.Len(t, hook.Entries, 0, "%+v", hook.Entries) + }) + + t.Run("case=broadcasts to multiple client connections", func(t *testing.T) { + ctxServer, c1, dir, cancel := setup(t) + defer cancel() + l, hook := test.NewNullLogger() + + fn := path.Join(dir, "some.file") + + handler, err := WatchAndServeWS(ctxServer, urlx.ParseOrPanic("file://"+fn), herodot.NewJSONWriter(l)) + require.NoError(t, err) + s := httptest.NewServer(handler) + + ctxClient1, cancelClient1 := context.WithCancel(context.Background()) + defer cancelClient1() + + u := urlx.ParseOrPanic("ws" + strings.TrimLeft(s.URL, "http")) + require.NoError(t, WatchWebsocket(ctxClient1, u, c1)) + + ctxClient2, cancelClient2 := context.WithCancel(context.Background()) + defer cancelClient2() + c2 := make(EventChannel) + require.NoError(t, WatchWebsocket(ctxClient2, u, c2)) + + f, err := os.Create(fn) + require.NoError(t, err) + require.NoError(t, f.Close()) + + assertChange(t, <-c1, "", u.String()+fn) + assertChange(t, <-c2, "", u.String()+fn) + + assert.Len(t, hook.Entries, 0, "%+v", hook.Entries) + }) +}