diff --git a/httpx/websocket.go b/httpx/websocket.go index 7fe9a74..8fd5022 100644 --- a/httpx/websocket.go +++ b/httpx/websocket.go @@ -17,6 +17,9 @@ const ( // how often to send a ping message pingPeriod = 30 * time.Second + + // maximum time to wait for writer to drain when closing + drainPeriod = 3 * time.Second ) var upgrader = websocket.Upgrader{ @@ -45,18 +48,28 @@ type WebSocket interface { OnClose(func(int)) } +type message struct { + type_ int + data []byte +} + // WebSocket implemention using gorilla library type socket struct { - conn *websocket.Conn - onMessage func([]byte) - onClose func(int) - outbox chan []byte - readError chan error - writeError chan error - stopWriter chan bool - closingWithCode int - rwWaitGroup sync.WaitGroup + conn *websocket.Conn + outbox chan message + + readError chan error + writeError chan error + shutdown chan bool + stopWriter chan bool + closingWithCode int + + readerWaitGroup sync.WaitGroup + writerWaitGroup sync.WaitGroup monitorWaitGroup sync.WaitGroup + + onMessage func([]byte) + onClose func(int) } // NewWebSocket creates a new web socket from a regular HTTP request @@ -69,13 +82,16 @@ func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, se conn.SetReadLimit(maxReadBytes) return &socket{ - conn: conn, - onMessage: func([]byte) {}, - onClose: func(int) {}, - outbox: make(chan []byte, sendBuffer), + conn: conn, + outbox: make(chan message, sendBuffer), + readError: make(chan error, 1), writeError: make(chan error, 1), - stopWriter: make(chan bool, 1), + shutdown: make(chan bool, 1), + stopWriter: make(chan bool), + + onMessage: defaultOnMessage, + onClose: defaultOnClose, }, nil } @@ -83,6 +99,10 @@ func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn } func (s *socket) OnClose(fn func(int)) { s.onClose = fn } func (s *socket) Start() { + if s.closingWithCode != 0 { + panic("can't start socket which is closed or closing") + } + s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) s.conn.SetPongHandler(s.pong) @@ -92,14 +112,23 @@ func (s *socket) Start() { } func (s *socket) Send(msg []byte) { - s.outbox <- msg + if s.closingWithCode != 0 { + panic("can't send to socket which is closed or closing") + } + + s.outbox <- message{type_: websocket.TextMessage, data: msg} } func (s *socket) Close(code int) { + if s.closingWithCode != 0 { + panic("can't close socket which is already closed or closing") + } + s.closingWithCode = code - s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, "")) - s.conn.Close() // causes reader to stop - s.stopWriter <- true + + s.outbox <- message{type_: websocket.CloseMessage, data: websocket.FormatCloseMessage(code, "")} + + s.shutdown <- true s.monitorWaitGroup.Wait() } @@ -114,6 +143,7 @@ func (s *socket) monitor() { s.monitorWaitGroup.Add(1) defer s.monitorWaitGroup.Done() + // shutdown starts via read error, write error, or Close() out: for { select { @@ -121,25 +151,31 @@ out: if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 { s.closingWithCode = e.Code } - s.stopWriter <- true // ensure writer is stopped break out case err := <-s.writeError: - if e, ok := err.(*websocket.CloseError); ok { + if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 { s.closingWithCode = e.Code } - s.conn.Close() // ensure reader is stopped + break out + case <-s.shutdown: break out } } - s.rwWaitGroup.Wait() + // stop writer if not already finished... + s.stopWriter <- true + s.writerWaitGroup.Wait() + + // stop reader if not already finished... + s.conn.Close() + s.readerWaitGroup.Wait() s.onClose(s.closingWithCode) } func (s *socket) reader() { - s.rwWaitGroup.Add(1) - defer s.rwWaitGroup.Done() + s.readerWaitGroup.Add(1) + defer s.readerWaitGroup.Done() for { _, message, err := s.conn.ReadMessage() @@ -153,31 +189,49 @@ func (s *socket) reader() { } func (s *socket) writer() { - s.rwWaitGroup.Add(1) - defer s.rwWaitGroup.Done() + s.writerWaitGroup.Add(1) + defer s.writerWaitGroup.Done() ticker := time.NewTicker(pingPeriod) defer ticker.Stop() +out: for { select { case msg := <-s.outbox: s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) - err := s.conn.WriteMessage(websocket.TextMessage, msg) - if err != nil { + if err := s.conn.WriteMessage(msg.type_, msg.data); err != nil { s.writeError <- err - return } case <-ticker.C: s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { s.writeError <- err - return } case <-s.stopWriter: - return + break out + } + } + + // try to drain the outbox with a time limit + if len(s.outbox) > 0 { + s.conn.SetWriteDeadline(time.Now().Add(drainPeriod)) + for { + select { + case msg := <-s.outbox: + err := s.conn.WriteMessage(msg.type_, msg.data) + if err != nil || len(s.outbox) == 0 { + return + } + case <-time.After(drainPeriod): + return + } } } } + +func defaultOnMessage([]byte) {} + +func defaultOnClose(int) {} diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go index 17a23eb..c9a3f1f 100644 --- a/httpx/websocket_test.go +++ b/httpx/websocket_test.go @@ -97,12 +97,25 @@ func TestSocketMessages(t *testing.T) { return nil }) + sock.Send([]byte("closing time")) sock.Close(1001) + conn.ReadMessage() // read the final message conn.ReadMessage() // read the close message assert.Equal(t, 1001, serverCloseCode) assert.Equal(t, 1001, connCloseCode) + + // check we can no longer send to the socket or close it again, or restart it + assert.Panics(t, func() { + sock.Send([]byte("x")) + }) + assert.Panics(t, func() { + sock.Close(1000) + }) + assert.Panics(t, func() { + sock.Start() + }) } func TestSocketClientCloseWithMessage(t *testing.T) {