Skip to content

Commit

Permalink
Merge pull request #108 from nyaruka/websockets
Browse files Browse the repository at this point in the history
Add WebSocket functionality to httpx
  • Loading branch information
rowanseymour authored Jan 12, 2024
2 parents ac4221c + 35142c6 commit 842f638
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/gorilla/websocket v1.5.1
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
Expand Down
177 changes: 177 additions & 0 deletions httpx/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package httpx

import (
"net/http"
"sync"
"time"

"github.com/gorilla/websocket"
)

const (
// max time for between reading a message before socket is considered closed
maxReadWait = 60 * time.Second

// maximum time to wait for message to be written
maxWriteWait = 15 * time.Second

// how often to send a ping message
pingPeriod = 30 * time.Second
)

var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}

// WebSocket provides a websocket interface similar to that of Javascript.
type WebSocket interface {
// Start begins reading and writing of messages on this socket
Start()

// Send sends the given message over the socket
Send([]byte)

// Close closes the socket connection
Close(int)

// OnMessage is called when the socket receives a message
OnMessage(func([]byte))

// OnClose is called when the socket is closed (even if we initiate the close)
OnClose(func(int))
}

// WebSocket implemention using gorilla library
type socket struct {
conn *websocket.Conn
onMessage func([]byte)
onClose func(int)
outbox chan []byte
readError chan error
writeError chan error
stopWriter chan bool
closingWithCode int
rwWaitGroup sync.WaitGroup
monitorWaitGroup sync.WaitGroup
}

// NewWebSocket creates a new web socket from a regular HTTP request
func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}

conn.SetReadLimit(maxReadBytes)

return &socket{
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},
outbox: make(chan []byte, sendBuffer),
readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
}, nil
}

func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn }
func (s *socket) OnClose(fn func(int)) { s.onClose = fn }

func (s *socket) Start() {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))
s.conn.SetPongHandler(s.pong)

go s.monitor()
go s.reader()
go s.writer()
}

func (s *socket) Send(msg []byte) {
s.outbox <- msg
}

func (s *socket) Close(code int) {
s.closingWithCode = code
s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, ""))
s.conn.Close() // causes reader to stop
s.stopWriter <- true

s.monitorWaitGroup.Wait()
}

func (s *socket) pong(m string) error {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))

return nil
}

func (s *socket) monitor() {
s.monitorWaitGroup.Add(1)
defer s.monitorWaitGroup.Done()

out:
for {
select {
case err := <-s.readError:
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 {
s.closingWithCode = e.Code
}
s.stopWriter <- true // ensure writer is stopped
break out
case err := <-s.writeError:
if e, ok := err.(*websocket.CloseError); ok {
s.closingWithCode = e.Code
}
s.conn.Close() // ensure reader is stopped
break out
}
}

s.rwWaitGroup.Wait()

s.onClose(s.closingWithCode)
}

func (s *socket) reader() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

for {
_, message, err := s.conn.ReadMessage()
if err != nil {
s.readError <- err
return
}

s.onMessage(message)
}
}

func (s *socket) writer() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()

for {
select {
case msg := <-s.outbox:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

err := s.conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
s.writeError <- err
return
}
case <-ticker.C:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
s.writeError <- err
return
}
case <-s.stopWriter:
return
}
}
}
140 changes: 140 additions & 0 deletions httpx/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package httpx_test

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/nyaruka/gocommon/httpx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sock, err := httpx.NewWebSocket(w, r, 4096, 5)
require.NoError(t, err)

fn(sock)
}))

return "ws:" + strings.TrimPrefix(s.URL, "http:")
}

func newSocketConnection(t *testing.T, url string) *websocket.Conn {
d := websocket.Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
c, _, err := d.Dial(url, nil)
assert.NoError(t, err)
return c
}

func TestSocketMessages(t *testing.T) {
var sock httpx.WebSocket
var serverReceived [][]byte
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnMessage(func(b []byte) {
serverReceived = append(serverReceived, b)
})
sock.OnClose(func(code int) {
serverCloseCode = code
})
sock.Start()
})

conn := newSocketConnection(t, serverURL)

// send a message from the server...
sock.Send([]byte("from server"))

// and read it from the client
msgType, msg, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, 1, msgType)
assert.Equal(t, "from server", string(msg))

// send a message from the client...
conn.WriteMessage(websocket.TextMessage, []byte("to server"))

// and check server received it
time.Sleep(500 * time.Millisecond)
assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived)

pongReceived := false
conn.SetPongHandler(func(appData string) error {
pongReceived = true
return nil
})

// send a ping message from the client...
conn.WriteMessage(websocket.PingMessage, []byte{})

// and give server time to receive it and respond
time.Sleep(500 * time.Millisecond)

// give the connection something to read because ReadMessage will block until it gets a non-ping-pong message
sock.Send([]byte("dummy"))
conn.ReadMessage()

assert.True(t, pongReceived)

var connCloseCode int
conn.SetCloseHandler(func(code int, text string) error {
connCloseCode = code
return nil
})

sock.Close(1001)

conn.ReadMessage() // read the close message

assert.Equal(t, 1001, serverCloseCode)
assert.Equal(t, 1001, connCloseCode)
}

func TestSocketClientCloseWithMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, ""))
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode)
}

func TestSocketClientCloseWithoutMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode)
}

0 comments on commit 842f638

Please sign in to comment.