Skip to content

Commit

Permalink
chore:unify option (#11)
Browse files Browse the repository at this point in the history
* chore:unify option

* docs:fix a description

* docs:replace validator with callback

* chore:change client struct
  • Loading branch information
ViolaPioggia authored Nov 23, 2023
1 parent 6514a9e commit 6305346
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 52 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// touch off when connected to the server
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

// touch off when the connection is shutdown
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down Expand Up @@ -126,13 +126,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// touch off when connected to the server
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method", c.GetURL(), c.GetMethod())
})

// touch off when the connection is shutdown
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down
16 changes: 8 additions & 8 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// 连接到服务端的时候触发
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

// 服务端断开连接的时候触发
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down Expand Up @@ -124,13 +124,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// 连接到服务端的时候触发
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method",c.GetURL(), c.GetMethod())
})

// 服务端断开连接的时候触发
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down
109 changes: 82 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,34 @@ var (
// ConnCallback defines a function to be called on a particular connection event
type ConnCallback func(ctx context.Context, client *Client)

// ResponseValidator validates a response
type ResponseValidator func(ctx context.Context, req *protocol.Request, resp *protocol.Response) error
// ResponseCallback validates a response
type ResponseCallback func(ctx context.Context, req *protocol.Request, resp *protocol.Response) error

// Client handles an incoming server stream
type Client struct {
HertzClient *client.Client
hertzClient *client.Client
disconnectCallback ConnCallback
connectedCallback ConnCallback
ResponseValidator ResponseValidator
Headers map[string]string
URL string
Method string
responseCallback ResponseCallback
headers map[string]string
url string
method string
maxBufferSize int
connected bool
EncodingBase64 bool
LastEventID atomic.Value // []byte
encodingBase64 bool
lastEventID atomic.Value // []byte
}

var defaultClient, _ = client.NewClient(client.WithDialer(standard.NewDialer()), client.WithResponseBodyStream(true))

// NewClient creates a new client
func NewClient(url string) *Client {
c := &Client{
URL: url,
HertzClient: defaultClient,
Headers: make(map[string]string),
url: url,
hertzClient: defaultClient,
headers: make(map[string]string),
maxBufferSize: 1 << 16,
Method: consts.MethodGet,
method: consts.MethodGet,
}

return c
Expand All @@ -91,8 +91,8 @@ func (c *Client) SubscribeWithContext(ctx context.Context, handler func(msg *Eve
protocol.ReleaseRequest(req)
protocol.ReleaseResponse(resp)
}()
if validator := c.ResponseValidator; validator != nil {
err = validator(ctx, req, resp)
if Callback := c.responseCallback; Callback != nil {
err = Callback(ctx, req, resp)
if err != nil {
return err
}
Expand Down Expand Up @@ -147,9 +147,9 @@ func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh
var msg *Event
if msg, err = c.processEvent(event); err == nil {
if len(msg.ID) > 0 {
c.LastEventID.Store(msg.ID)
c.lastEventID.Store(msg.ID)
} else {
msg.ID, _ = c.LastEventID.Load().(string)
msg.ID, _ = c.lastEventID.Load().(string)
}

// Send downstream if the event has something useful
Expand All @@ -160,13 +160,13 @@ func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh
}
}

// OnDisconnect specifies the function to run when the connection disconnects
func (c *Client) OnDisconnect(fn ConnCallback) {
// SetDisconnectCallback specifies the function to run when the connection disconnects
func (c *Client) SetDisconnectCallback(fn ConnCallback) {
c.disconnectCallback = fn
}

// OnConnect specifies the function to run when the connection is successful
func (c *Client) OnConnect(fn ConnCallback) {
// SetOnConnectCallback specifies the function to run when the connection is successful
func (c *Client) SetOnConnectCallback(fn ConnCallback) {
c.connectedCallback = fn
}

Expand All @@ -175,24 +175,79 @@ func (c *Client) SetMaxBufferSize(size int) {
c.maxBufferSize = size
}

// SetURL set sse client url
func (c *Client) SetURL(url string) {
c.url = url
}

// SetMethod set sse client request method
func (c *Client) SetMethod(method string) {
c.method = method
}

// SetHeaders set sse client headers
func (c *Client) SetHeaders(headers map[string]string) {
c.headers = headers
}

// SetResponseCallback set sse client responseCallback
func (c *Client) SetResponseCallback(responseCallback ResponseCallback) {
c.responseCallback = responseCallback
}

// SetHertzClient set sse client
func (c *Client) SetHertzClient(hertzClient *client.Client) {
c.hertzClient = hertzClient
}

// SetEncodingBase64 set sse client whether use the base64
func (c *Client) SetEncodingBase64(encodingBase64 bool) {
c.encodingBase64 = encodingBase64
}

// GetURL get sse client url
func (c *Client) GetURL() string {
return c.url
}

// GetHeaders get sse client headers
func (c *Client) GetHeaders() map[string]string {
return c.headers
}

// GetMethod get sse client method
func (c *Client) GetMethod() string {
return c.method
}

// GetHertzClient get sse client
func (c *Client) GetHertzClient() *client.Client {
return c.hertzClient
}

// GetLastEventID get sse client lastEventID
func (c *Client) GetLastEventID() []byte {
return c.lastEventID.Load().([]byte)
}

func (c *Client) request(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
req.SetMethod(c.Method)
req.SetRequestURI(c.URL)
req.SetMethod(c.method)
req.SetRequestURI(c.url)

req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Connection", "keep-alive")

lastID, exists := c.LastEventID.Load().([]byte)
lastID, exists := c.lastEventID.Load().([]byte)
if exists && lastID != nil {
req.Header.Set(LastEventID, string(lastID))
}
// Add user specified headers
for k, v := range c.Headers {
for k, v := range c.headers {
req.Header.Set(k, v)
}

err := c.HertzClient.Do(ctx, req, resp)
err := c.hertzClient.Do(ctx, req, resp)
return err
}

Expand Down Expand Up @@ -227,7 +282,7 @@ func (c *Client) processEvent(msg []byte) (event *Event, err error) {
// Trim the last "\n" per the spec.
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))

if c.EncodingBase64 {
if c.encodingBase64 {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))

n, err := base64.StdEncoding.Decode(buf, e.Data)
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func TestClientOnConnect(t *testing.T) {
c := NewClient("http://127.0.0.1:9000/sse")

called := make(chan struct{})
c.OnConnect(func(ctx context.Context, client *Client) {
c.SetOnConnectCallback(func(ctx context.Context, client *Client) {
called <- struct{}{}
})

Expand Down
16 changes: 8 additions & 8 deletions examples/client/quickstart/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// touch off when connected to the server
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 connect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

// touch off when the connection is shutdown
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client1 disconnect to server %s success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down Expand Up @@ -89,13 +89,13 @@ func main() {
c := sse.NewClient("http://127.0.0.1:8888/sse")

// touch off when connected to the server
c.OnConnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method", c.URL, c.Method)
c.SetOnConnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s connect to server success with %s method", c.GetURL(), c.GetMethod())
})

// touch off when the connection is shutdown
c.OnDisconnect(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.URL, c.Method)
c.SetDisconnectCallback(func(ctx context.Context, client *sse.Client) {
hlog.Infof("client2 %s disconnect to server success with %s method", c.GetURL(), c.GetMethod())
})

events := make(chan *sse.Event)
Expand Down

0 comments on commit 6305346

Please sign in to comment.