Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ping pong functionality #45

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 115 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"encoding/json"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"sync"
"time"

"golang.org/x/net/http2"
Expand All @@ -33,13 +35,21 @@ var (
// HTTPClient. The timeout includes connection time, any redirects,
// and reading the response body.
HTTPClientTimeout = 30 * time.Second
// PingPongFrequency is the interval with which a client will PING APNs
// servers.
PingPongFrequency = 15 * time.Second
)

// Client represents a connection with the APNs
type Client struct {
HTTPClient *http.Client
Certificate tls.Certificate
Host string
conn net.Conn
pinging bool
newConnChan chan struct{}
stopChan chan struct{}
m *sync.Mutex
}

// NewClient returns a new Client with an underlying http.Client configured with
Expand All @@ -53,27 +63,44 @@ type Client struct {
//
// If your use case involves multiple long-lived connections, consider using
// the ClientManager, which manages clients for you.
func NewClient(certificate tls.Certificate) *Client {
//
// Alternatively, you can keep the clients connection healthy by calling
// EnablePinging, which will send PING frames to APNs servers with the interval
// specified via PingPongFrequency.
func NewClient(certificate tls.Certificate) (client *Client) {
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{certificate},
}
if len(certificate.Certificate) > 0 {
tlsConfig.BuildNameToCertificate()
}
client = &Client{
Certificate: certificate,
Host: DefaultHost,
newConnChan: make(chan struct{}),
stopChan: make(chan struct{}),
m: new(sync.Mutex),
}
transport := &http2.Transport{
TLSClientConfig: tlsConfig,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
DialTLS: func(network, addr string, cfg *tls.Config) (c net.Conn, e error) {
c, e = tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
if e == nil {
client.m.Lock()
defer client.m.Unlock()
client.conn = c
if client.pinging {
client.newConnChan <- struct{}{}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's blocks forever if we never call EnablePinging. Is it ok here? (Maybe I miss some context, sorry.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't since when pinging is false, the function won't try to send a signal to newConnChan, and releases the lock when it returns. When pinging is true, there is a goroutine that reads from newConnChan, so writing to it should never block. Am I missing something?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Sorry again :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

S'all good, the more eyes the better!

}
}
return
},
}
return &Client{
HTTPClient: &http.Client{
Transport: transport,
Timeout: HTTPClientTimeout,
},
Certificate: certificate,
Host: DefaultHost,
client.HTTPClient = &http.Client{
Transport: transport,
Timeout: HTTPClientTimeout,
}
return
}

// Development sets the Client to use the APNs development push endpoint.
Expand Down Expand Up @@ -120,6 +147,84 @@ func (c *Client) Push(n *Notification) (*Response, error) {
return response, nil
}

// EnablePinging tries to send PING frames to APNs servers whenever the client
// has a valid connection. If the willHandleDrops parameter is set to true, this
// function returns a read-only channel that gets notified when pinging fails.
// This allows the user to take actions to preemptively reinitialize the client's
// connection. The second return value indicates whether the call has successfully
// enabled pinging.
func (c *Client) EnablePinging(willHandleDrops bool) (<-chan struct{}, bool) {
c.m.Lock()
defer c.m.Unlock()
if c.pinging {
return nil, false
}
c.pinging = true
var dropSignal chan struct{}
if willHandleDrops {
dropSignal = make(chan struct{})
}
go func() {
// 8 bytes of random data used for PING-PONG, as per HTTP/2 spec.
var data [8]byte
rand.Read(data[:])
pinger := new(time.Ticker)
var framer *http2.Framer
c.m.Lock()
if c.conn != nil {
framer = http2.NewFramer(c.conn, c.conn)
pinger = time.NewTicker(PingPongFrequency)
}
c.m.Unlock()
for {
select {
case <-pinger.C:
err := framer.WritePing(false, data)
if err != nil {
// Could not PING the APNs server, stop trying
// and notify the drop handler, if there is any.
c.m.Lock()
c.conn = nil
c.m.Unlock()
framer = nil
pinger.Stop()
if willHandleDrops {
dropSignal <- struct{}{}
}
}
case <-c.newConnChan:
c.m.Lock()
framer = http2.NewFramer(c.conn, c.conn)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't the transport technically dial multiple connections? I guess in practice it doesn't, but is there anything in place stopping it from doing so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The source code and many production tests reveal that it does not.

c.m.Unlock()
pinger.Stop()
pinger = time.NewTicker(PingPongFrequency)
case <-c.stopChan:
pinger.Stop()
c.m.Lock()
defer c.m.Unlock()
c.conn = nil
framer = nil
return
}
}
}()
return dropSignal, true
}

// DisablePinging stops the pinging operation associated with the client, if
// there's any, and returns a boolean that indicates if the call has successfully
// stopped the pinging operation.
func (c *Client) DisablePinging() bool {
c.m.Lock()
defer c.m.Unlock()
if c.pinging {
c.pinging = false
c.stopChan <- struct{}{}
return true
}
return false
}

func setHeaders(r *http.Request, n *Notification) {
r.Header.Set("Content-Type", "application/json; charset=utf-8")
if n.Topic != "" {
Expand Down
69 changes: 69 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -233,3 +234,71 @@ func TestMalformedJSONResponse(t *testing.T) {
assert.Error(t, err)
assert.Equal(t, false, res.Sent())
}

func TestEnablePinging(t *testing.T) {
apns.PingPongFrequency = 50 * time.Millisecond
apns.TLSDialTimeout = 10 * time.Second
n := mockNotification()
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
http2.ConfigureServer(server.Config, nil)
server.TLS = server.Config.TLSConfig
server.StartTLS()
transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
http2.ConfigureTransport(transport)
certificate, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid.p12", "")
client := apns.NewClient(certificate)
client.Host = server.URL
client.HTTPClient.Transport.(*http2.Transport).TLSClientConfig = transport.TLSClientConfig
client.HTTPClient = &http.Client{Transport: client.HTTPClient.Transport}
drop, ok := client.EnablePinging(true)
assert.Equal(t, true, ok)
var gotDropped int32
go func() {
<-drop
atomic.StoreInt32(&gotDropped, 1)
}()
_, err := client.Push(n)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped)))
server.Close()
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 1, int(atomic.LoadInt32(&gotDropped)))
}

func TestDisablePinging(t *testing.T) {
n := mockNotification()
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
http2.ConfigureServer(server.Config, nil)
server.TLS = server.Config.TLSConfig
server.StartTLS()
transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
http2.ConfigureTransport(transport)
certificate, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid.p12", "")
client := apns.NewClient(certificate)
client.Host = server.URL
client.HTTPClient.Transport.(*http2.Transport).TLSClientConfig = transport.TLSClientConfig
client.HTTPClient = &http.Client{Transport: client.HTTPClient.Transport}
drop, ok := client.EnablePinging(true)
assert.Equal(t, true, ok)
var gotDropped int32
cleanUp := make(chan struct{})
go func() {
select {
case <-drop:
atomic.StoreInt32(&gotDropped, 1)
case <-cleanUp:
return
}
}()
_, err := client.Push(n)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped)))
ok = client.DisablePinging()
assert.Equal(t, true, ok)
server.Close()
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped)))
close(cleanUp)
}