Skip to content

Commit

Permalink
Merge pull request #110 from nyaruka/websocket_closing
Browse files Browse the repository at this point in the history
Fix controlled closing of websockets
  • Loading branch information
rowanseymour authored Jan 15, 2024
2 parents 6fc7181 + 6daeb9a commit 20e83c1
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 31 deletions.
116 changes: 85 additions & 31 deletions httpx/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const (

// how often to send a ping message
pingPeriod = 30 * time.Second

// maximum time to wait for writer to drain when closing
drainPeriod = 3 * time.Second
)

var upgrader = websocket.Upgrader{
Expand Down Expand Up @@ -45,18 +48,28 @@ type WebSocket interface {
OnClose(func(int))
}

type message struct {
type_ int
data []byte
}

// WebSocket implemention using gorilla library
type socket struct {
conn *websocket.Conn
onMessage func([]byte)
onClose func(int)
outbox chan []byte
readError chan error
writeError chan error
stopWriter chan bool
closingWithCode int
rwWaitGroup sync.WaitGroup
conn *websocket.Conn
outbox chan message

readError chan error
writeError chan error
shutdown chan bool
stopWriter chan bool
closingWithCode int

readerWaitGroup sync.WaitGroup
writerWaitGroup sync.WaitGroup
monitorWaitGroup sync.WaitGroup

onMessage func([]byte)
onClose func(int)
}

// NewWebSocket creates a new web socket from a regular HTTP request
Expand All @@ -69,20 +82,27 @@ func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, se
conn.SetReadLimit(maxReadBytes)

return &socket{
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},
outbox: make(chan []byte, sendBuffer),
conn: conn,
outbox: make(chan message, sendBuffer),

readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
shutdown: make(chan bool, 1),
stopWriter: make(chan bool),

onMessage: defaultOnMessage,
onClose: defaultOnClose,
}, nil
}

func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn }
func (s *socket) OnClose(fn func(int)) { s.onClose = fn }

func (s *socket) Start() {
if s.closingWithCode != 0 {
panic("can't start socket which is closed or closing")
}

s.conn.SetReadDeadline(time.Now().Add(maxReadWait))
s.conn.SetPongHandler(s.pong)

Expand All @@ -92,14 +112,23 @@ func (s *socket) Start() {
}

func (s *socket) Send(msg []byte) {
s.outbox <- msg
if s.closingWithCode != 0 {
panic("can't send to socket which is closed or closing")
}

s.outbox <- message{type_: websocket.TextMessage, data: msg}
}

func (s *socket) Close(code int) {
if s.closingWithCode != 0 {
panic("can't close socket which is already closed or closing")
}

s.closingWithCode = code
s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, ""))
s.conn.Close() // causes reader to stop
s.stopWriter <- true

s.outbox <- message{type_: websocket.CloseMessage, data: websocket.FormatCloseMessage(code, "")}

s.shutdown <- true

s.monitorWaitGroup.Wait()
}
Expand All @@ -114,32 +143,39 @@ func (s *socket) monitor() {
s.monitorWaitGroup.Add(1)
defer s.monitorWaitGroup.Done()

// shutdown starts via read error, write error, or Close()
out:
for {
select {
case err := <-s.readError:
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 {
s.closingWithCode = e.Code
}
s.stopWriter <- true // ensure writer is stopped
break out
case err := <-s.writeError:
if e, ok := err.(*websocket.CloseError); ok {
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 {
s.closingWithCode = e.Code
}
s.conn.Close() // ensure reader is stopped
break out
case <-s.shutdown:
break out
}
}

s.rwWaitGroup.Wait()
// stop writer if not already finished...
s.stopWriter <- true
s.writerWaitGroup.Wait()

// stop reader if not already finished...
s.conn.Close()
s.readerWaitGroup.Wait()

s.onClose(s.closingWithCode)
}

func (s *socket) reader() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()
s.readerWaitGroup.Add(1)
defer s.readerWaitGroup.Done()

for {
_, message, err := s.conn.ReadMessage()
Expand All @@ -153,31 +189,49 @@ func (s *socket) reader() {
}

func (s *socket) writer() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()
s.writerWaitGroup.Add(1)
defer s.writerWaitGroup.Done()

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()

out:
for {
select {
case msg := <-s.outbox:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

err := s.conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
if err := s.conn.WriteMessage(msg.type_, msg.data); err != nil {
s.writeError <- err
return
}
case <-ticker.C:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
s.writeError <- err
return
}
case <-s.stopWriter:
return
break out
}
}

// try to drain the outbox with a time limit
if len(s.outbox) > 0 {
s.conn.SetWriteDeadline(time.Now().Add(drainPeriod))
for {
select {
case msg := <-s.outbox:
err := s.conn.WriteMessage(msg.type_, msg.data)
if err != nil || len(s.outbox) == 0 {
return
}
case <-time.After(drainPeriod):
return
}
}
}
}

func defaultOnMessage([]byte) {}

func defaultOnClose(int) {}
13 changes: 13 additions & 0 deletions httpx/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,25 @@ func TestSocketMessages(t *testing.T) {
return nil
})

sock.Send([]byte("closing time"))
sock.Close(1001)

conn.ReadMessage() // read the final message
conn.ReadMessage() // read the close message

assert.Equal(t, 1001, serverCloseCode)
assert.Equal(t, 1001, connCloseCode)

// check we can no longer send to the socket or close it again, or restart it
assert.Panics(t, func() {
sock.Send([]byte("x"))
})
assert.Panics(t, func() {
sock.Close(1000)
})
assert.Panics(t, func() {
sock.Start()
})
}

func TestSocketClientCloseWithMessage(t *testing.T) {
Expand Down

0 comments on commit 20e83c1

Please sign in to comment.