diff --git a/client.go b/client.go index 6e82bcfe..42cbc1e7 100644 --- a/client.go +++ b/client.go @@ -9,8 +9,10 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net" "net/http" + "sync" "time" "golang.org/x/net/http2" @@ -33,6 +35,9 @@ var ( // HTTPClient. The timeout includes connection time, any redirects, // and reading the response body. HTTPClientTimeout = 30 * time.Second + // PingPongFrequency is the interval with which a client will PING APNs + // servers. + PingPongFrequency = 15 * time.Second ) // Client represents a connection with the APNs @@ -40,6 +45,11 @@ type Client struct { HTTPClient *http.Client Certificate tls.Certificate Host string + conn net.Conn + pinging bool + newConnChan chan struct{} + stopChan chan struct{} + m *sync.Mutex } // NewClient returns a new Client with an underlying http.Client configured with @@ -53,27 +63,44 @@ type Client struct { // // If your use case involves multiple long-lived connections, consider using // the ClientManager, which manages clients for you. -func NewClient(certificate tls.Certificate) *Client { +// +// Alternatively, you can keep the clients connection healthy by calling +// EnablePinging, which will send PING frames to APNs servers with the interval +// specified via PingPongFrequency. +func NewClient(certificate tls.Certificate) (client *Client) { tlsConfig := &tls.Config{ Certificates: []tls.Certificate{certificate}, } if len(certificate.Certificate) > 0 { tlsConfig.BuildNameToCertificate() } + client = &Client{ + Certificate: certificate, + Host: DefaultHost, + newConnChan: make(chan struct{}), + stopChan: make(chan struct{}), + m: new(sync.Mutex), + } transport := &http2.Transport{ TLSClientConfig: tlsConfig, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg) + DialTLS: func(network, addr string, cfg *tls.Config) (c net.Conn, e error) { + c, e = tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg) + if e == nil { + client.m.Lock() + defer client.m.Unlock() + client.conn = c + if client.pinging { + client.newConnChan <- struct{}{} + } + } + return }, } - return &Client{ - HTTPClient: &http.Client{ - Transport: transport, - Timeout: HTTPClientTimeout, - }, - Certificate: certificate, - Host: DefaultHost, + client.HTTPClient = &http.Client{ + Transport: transport, + Timeout: HTTPClientTimeout, } + return } // Development sets the Client to use the APNs development push endpoint. @@ -120,6 +147,84 @@ func (c *Client) Push(n *Notification) (*Response, error) { return response, nil } +// EnablePinging tries to send PING frames to APNs servers whenever the client +// has a valid connection. If the willHandleDrops parameter is set to true, this +// function returns a read-only channel that gets notified when pinging fails. +// This allows the user to take actions to preemptively reinitialize the client's +// connection. The second return value indicates whether the call has successfully +// enabled pinging. +func (c *Client) EnablePinging(willHandleDrops bool) (<-chan struct{}, bool) { + c.m.Lock() + defer c.m.Unlock() + if c.pinging { + return nil, false + } + c.pinging = true + var dropSignal chan struct{} + if willHandleDrops { + dropSignal = make(chan struct{}) + } + go func() { + // 8 bytes of random data used for PING-PONG, as per HTTP/2 spec. + var data [8]byte + rand.Read(data[:]) + pinger := new(time.Ticker) + var framer *http2.Framer + c.m.Lock() + if c.conn != nil { + framer = http2.NewFramer(c.conn, c.conn) + pinger = time.NewTicker(PingPongFrequency) + } + c.m.Unlock() + for { + select { + case <-pinger.C: + err := framer.WritePing(false, data) + if err != nil { + // Could not PING the APNs server, stop trying + // and notify the drop handler, if there is any. + c.m.Lock() + c.conn = nil + c.m.Unlock() + framer = nil + pinger.Stop() + if willHandleDrops { + dropSignal <- struct{}{} + } + } + case <-c.newConnChan: + c.m.Lock() + framer = http2.NewFramer(c.conn, c.conn) + c.m.Unlock() + pinger.Stop() + pinger = time.NewTicker(PingPongFrequency) + case <-c.stopChan: + pinger.Stop() + c.m.Lock() + defer c.m.Unlock() + c.conn = nil + framer = nil + return + } + } + }() + return dropSignal, true +} + +// DisablePinging stops the pinging operation associated with the client, if +// there's any, and returns a boolean that indicates if the call has successfully +// stopped the pinging operation. +func (c *Client) DisablePinging() bool { + c.m.Lock() + defer c.m.Unlock() + if c.pinging { + c.pinging = false + c.stopChan <- struct{}{} + return true + } + return false +} + func setHeaders(r *http.Request, n *Notification) { r.Header.Set("Content-Type", "application/json; charset=utf-8") if n.Topic != "" { diff --git a/client_test.go b/client_test.go index 669e22e2..ee841810 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -233,3 +234,71 @@ func TestMalformedJSONResponse(t *testing.T) { assert.Error(t, err) assert.Equal(t, false, res.Sent()) } + +func TestEnablePinging(t *testing.T) { + apns.PingPongFrequency = 50 * time.Millisecond + apns.TLSDialTimeout = 10 * time.Second + n := mockNotification() + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + http2.ConfigureServer(server.Config, nil) + server.TLS = server.Config.TLSConfig + server.StartTLS() + transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + http2.ConfigureTransport(transport) + certificate, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid.p12", "") + client := apns.NewClient(certificate) + client.Host = server.URL + client.HTTPClient.Transport.(*http2.Transport).TLSClientConfig = transport.TLSClientConfig + client.HTTPClient = &http.Client{Transport: client.HTTPClient.Transport} + drop, ok := client.EnablePinging(true) + assert.Equal(t, true, ok) + var gotDropped int32 + go func() { + <-drop + atomic.StoreInt32(&gotDropped, 1) + }() + _, err := client.Push(n) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped))) + server.Close() + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 1, int(atomic.LoadInt32(&gotDropped))) +} + +func TestDisablePinging(t *testing.T) { + n := mockNotification() + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + http2.ConfigureServer(server.Config, nil) + server.TLS = server.Config.TLSConfig + server.StartTLS() + transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + http2.ConfigureTransport(transport) + certificate, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid.p12", "") + client := apns.NewClient(certificate) + client.Host = server.URL + client.HTTPClient.Transport.(*http2.Transport).TLSClientConfig = transport.TLSClientConfig + client.HTTPClient = &http.Client{Transport: client.HTTPClient.Transport} + drop, ok := client.EnablePinging(true) + assert.Equal(t, true, ok) + var gotDropped int32 + cleanUp := make(chan struct{}) + go func() { + select { + case <-drop: + atomic.StoreInt32(&gotDropped, 1) + case <-cleanUp: + return + } + }() + _, err := client.Push(n) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped))) + ok = client.DisablePinging() + assert.Equal(t, true, ok) + server.Close() + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 0, int(atomic.LoadInt32(&gotDropped))) + close(cleanUp) +}