Skip to content

Commit

Permalink
feat: add unsubscribe
Browse files Browse the repository at this point in the history
  • Loading branch information
ViolaPioggia committed Jan 17, 2024
1 parent 688b860 commit 003d4f7
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func main() {
}
}()

select {}
wg.Wait()
}

```
Expand Down
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (c *Client) startReadLoop(ctx context.Context, reader *EventStreamReader) (
func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) {
for {
// Read each new line and process the type of event
event, err := reader.ReadEvent()
event, err := reader.ReadEvent(ctx)
if err != nil {
if err == io.EOF {
erChan <- nil
Expand Down
26 changes: 26 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,32 @@ func TestClientSubscribe(t *testing.T) {
assert.Nil(t, cErr)
}

func TestClientUnSubscribe(t *testing.T) {
go newServer(false, "8887")
time.Sleep(time.Second)
c := NewClient("http://127.0.0.1:8887/sse")

events := make(chan *Event)
ctx, cancel := context.WithCancel(context.Background())
var cErr error
go func() {
cErr = c.SubscribeWithContext(ctx, func(msg *Event) {
if msg.Data != nil {
events <- msg
return
}
})
}()
cancel()
time.Sleep(5 * time.Second)
for i := 0; i < 5; i++ {
_, err := wait(events, time.Second*1)
assert.DeepEqual(t, errors.New("timeout"), err)
}

assert.Nil(t, cErr)
}

func TestClientSubscribeMultiline(t *testing.T) {
go newMultilineServer("9007")
time.Sleep(time.Second)
Expand Down
12 changes: 9 additions & 3 deletions event.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ package sse
import (
"bufio"
"bytes"
"context"
"io"
)

Expand Down Expand Up @@ -119,10 +120,15 @@ func minPosInt(a, b int) int {
}

// ReadEvent scans the EventStream for events.
func (e *EventStreamReader) ReadEvent() ([]byte, error) {
func (e *EventStreamReader) ReadEvent(ctx context.Context) ([]byte, error) {
if e.scanner.Scan() {
event := e.scanner.Bytes()
return event, nil
select {
case <-ctx.Done():
return nil, io.EOF
default:
event := e.scanner.Bytes()
return event, nil
}
}
if err := e.scanner.Err(); err != nil {
return nil, err
Expand Down
16 changes: 11 additions & 5 deletions examples/client/quickstart/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ package main

import (
"context"
"sync"

"github.com/hertz-contrib/sse"

"fmt"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/hertz-contrib/sse"
"sync"
"time"
)

var wg sync.WaitGroup
Expand All @@ -64,15 +64,21 @@ func main() {

events := make(chan *sse.Event)
errChan := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() {
cErr := c.Subscribe(func(msg *sse.Event) {
cErr := c.SubscribeWithContext(ctx, func(msg *sse.Event) {
if msg.Data != nil {
events <- msg
return
}
})
errChan <- cErr
}()
go func() {
time.Sleep(5 * time.Second)
cancel()
fmt.Println("client1 subscribe cancel")
}()
for {
select {
case e := <-events:
Expand Down

0 comments on commit 003d4f7

Please sign in to comment.