diff --git a/README_CN.md b/README_CN.md index 1fb14ac..5634352 100644 --- a/README_CN.md +++ b/README_CN.md @@ -156,7 +156,7 @@ func main() { } }() - select {} + wg.Wait() } ``` diff --git a/client.go b/client.go index 1e5b2c8..13077be 100644 --- a/client.go +++ b/client.go @@ -124,7 +124,7 @@ func (c *Client) startReadLoop(ctx context.Context, reader *EventStreamReader) ( func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) { for { // Read each new line and process the type of event - event, err := reader.ReadEvent() + event, err := reader.ReadEvent(ctx) if err != nil { if err == io.EOF { erChan <- nil diff --git a/client_test.go b/client_test.go index c2df161..8279796 100644 --- a/client_test.go +++ b/client_test.go @@ -197,6 +197,32 @@ func TestClientSubscribe(t *testing.T) { assert.Nil(t, cErr) } +func TestClientUnSubscribe(t *testing.T) { + go newServer(false, "8887") + time.Sleep(time.Second) + c := NewClient("http://127.0.0.1:8887/sse") + + events := make(chan *Event) + ctx, cancel := context.WithCancel(context.Background()) + var cErr error + go func() { + cErr = c.SubscribeWithContext(ctx, func(msg *Event) { + if msg.Data != nil { + events <- msg + return + } + }) + }() + cancel() + time.Sleep(5 * time.Second) + for i := 0; i < 5; i++ { + _, err := wait(events, time.Second*1) + assert.DeepEqual(t, errors.New("timeout"), err) + } + + assert.Nil(t, cErr) +} + func TestClientSubscribeMultiline(t *testing.T) { go newMultilineServer("9007") time.Sleep(time.Second) diff --git a/event.go b/event.go index fd0a028..e9b47fb 100644 --- a/event.go +++ b/event.go @@ -40,6 +40,7 @@ package sse import ( "bufio" "bytes" + "context" "io" ) @@ -119,10 +120,15 @@ func minPosInt(a, b int) int { } // ReadEvent scans the EventStream for events. -func (e *EventStreamReader) ReadEvent() ([]byte, error) { +func (e *EventStreamReader) ReadEvent(ctx context.Context) ([]byte, error) { if e.scanner.Scan() { - event := e.scanner.Bytes() - return event, nil + select { + case <-ctx.Done(): + return nil, io.EOF + default: + event := e.scanner.Bytes() + return event, nil + } } if err := e.scanner.Err(); err != nil { return nil, err diff --git a/examples/client/quickstart/main.go b/examples/client/quickstart/main.go index f2de8cb..75c0a5b 100644 --- a/examples/client/quickstart/main.go +++ b/examples/client/quickstart/main.go @@ -38,11 +38,11 @@ package main import ( "context" - "sync" - - "github.com/hertz-contrib/sse" - + "fmt" "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/hertz-contrib/sse" + "sync" + "time" ) var wg sync.WaitGroup @@ -64,8 +64,9 @@ func main() { events := make(chan *sse.Event) errChan := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) go func() { - cErr := c.Subscribe(func(msg *sse.Event) { + cErr := c.SubscribeWithContext(ctx, func(msg *sse.Event) { if msg.Data != nil { events <- msg return @@ -73,6 +74,11 @@ func main() { }) errChan <- cErr }() + go func() { + time.Sleep(5 * time.Second) + cancel() + fmt.Println("client1 subscribe cancel") + }() for { select { case e := <-events: