-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #108 from nyaruka/websockets
Add WebSocket functionality to httpx
- Loading branch information
Showing
4 changed files
with
320 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package httpx | ||
|
||
import ( | ||
"net/http" | ||
"sync" | ||
"time" | ||
|
||
"github.com/gorilla/websocket" | ||
) | ||
|
||
const ( | ||
// max time for between reading a message before socket is considered closed | ||
maxReadWait = 60 * time.Second | ||
|
||
// maximum time to wait for message to be written | ||
maxWriteWait = 15 * time.Second | ||
|
||
// how often to send a ping message | ||
pingPeriod = 30 * time.Second | ||
) | ||
|
||
var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} | ||
|
||
// WebSocket provides a websocket interface similar to that of Javascript. | ||
type WebSocket interface { | ||
// Start begins reading and writing of messages on this socket | ||
Start() | ||
|
||
// Send sends the given message over the socket | ||
Send([]byte) | ||
|
||
// Close closes the socket connection | ||
Close(int) | ||
|
||
// OnMessage is called when the socket receives a message | ||
OnMessage(func([]byte)) | ||
|
||
// OnClose is called when the socket is closed (even if we initiate the close) | ||
OnClose(func(int)) | ||
} | ||
|
||
// WebSocket implemention using gorilla library | ||
type socket struct { | ||
conn *websocket.Conn | ||
onMessage func([]byte) | ||
onClose func(int) | ||
outbox chan []byte | ||
readError chan error | ||
writeError chan error | ||
stopWriter chan bool | ||
closingWithCode int | ||
rwWaitGroup sync.WaitGroup | ||
monitorWaitGroup sync.WaitGroup | ||
} | ||
|
||
// NewWebSocket creates a new web socket from a regular HTTP request | ||
func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { | ||
conn, err := upgrader.Upgrade(w, r, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
conn.SetReadLimit(maxReadBytes) | ||
|
||
return &socket{ | ||
conn: conn, | ||
onMessage: func([]byte) {}, | ||
onClose: func(int) {}, | ||
outbox: make(chan []byte, sendBuffer), | ||
readError: make(chan error, 1), | ||
writeError: make(chan error, 1), | ||
stopWriter: make(chan bool, 1), | ||
}, nil | ||
} | ||
|
||
func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn } | ||
func (s *socket) OnClose(fn func(int)) { s.onClose = fn } | ||
|
||
func (s *socket) Start() { | ||
s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) | ||
s.conn.SetPongHandler(s.pong) | ||
|
||
go s.monitor() | ||
go s.reader() | ||
go s.writer() | ||
} | ||
|
||
func (s *socket) Send(msg []byte) { | ||
s.outbox <- msg | ||
} | ||
|
||
func (s *socket) Close(code int) { | ||
s.closingWithCode = code | ||
s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, "")) | ||
s.conn.Close() // causes reader to stop | ||
s.stopWriter <- true | ||
|
||
s.monitorWaitGroup.Wait() | ||
} | ||
|
||
func (s *socket) pong(m string) error { | ||
s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) | ||
|
||
return nil | ||
} | ||
|
||
func (s *socket) monitor() { | ||
s.monitorWaitGroup.Add(1) | ||
defer s.monitorWaitGroup.Done() | ||
|
||
out: | ||
for { | ||
select { | ||
case err := <-s.readError: | ||
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 { | ||
s.closingWithCode = e.Code | ||
} | ||
s.stopWriter <- true // ensure writer is stopped | ||
break out | ||
case err := <-s.writeError: | ||
if e, ok := err.(*websocket.CloseError); ok { | ||
s.closingWithCode = e.Code | ||
} | ||
s.conn.Close() // ensure reader is stopped | ||
break out | ||
} | ||
} | ||
|
||
s.rwWaitGroup.Wait() | ||
|
||
s.onClose(s.closingWithCode) | ||
} | ||
|
||
func (s *socket) reader() { | ||
s.rwWaitGroup.Add(1) | ||
defer s.rwWaitGroup.Done() | ||
|
||
for { | ||
_, message, err := s.conn.ReadMessage() | ||
if err != nil { | ||
s.readError <- err | ||
return | ||
} | ||
|
||
s.onMessage(message) | ||
} | ||
} | ||
|
||
func (s *socket) writer() { | ||
s.rwWaitGroup.Add(1) | ||
defer s.rwWaitGroup.Done() | ||
|
||
ticker := time.NewTicker(pingPeriod) | ||
defer ticker.Stop() | ||
|
||
for { | ||
select { | ||
case msg := <-s.outbox: | ||
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) | ||
|
||
err := s.conn.WriteMessage(websocket.TextMessage, msg) | ||
if err != nil { | ||
s.writeError <- err | ||
return | ||
} | ||
case <-ticker.C: | ||
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) | ||
|
||
if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { | ||
s.writeError <- err | ||
return | ||
} | ||
case <-s.stopWriter: | ||
return | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
package httpx_test | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gorilla/websocket" | ||
"github.com/nyaruka/gocommon/httpx" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string { | ||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
sock, err := httpx.NewWebSocket(w, r, 4096, 5) | ||
require.NoError(t, err) | ||
|
||
fn(sock) | ||
})) | ||
|
||
return "ws:" + strings.TrimPrefix(s.URL, "http:") | ||
} | ||
|
||
func newSocketConnection(t *testing.T, url string) *websocket.Conn { | ||
d := websocket.Dialer{ | ||
Subprotocols: []string{"p1", "p2"}, | ||
ReadBufferSize: 1024, | ||
WriteBufferSize: 1024, | ||
HandshakeTimeout: 30 * time.Second, | ||
} | ||
c, _, err := d.Dial(url, nil) | ||
assert.NoError(t, err) | ||
return c | ||
} | ||
|
||
func TestSocketMessages(t *testing.T) { | ||
var sock httpx.WebSocket | ||
var serverReceived [][]byte | ||
var serverCloseCode int | ||
|
||
serverURL := newSocketServer(t, func(ws httpx.WebSocket) { | ||
sock = ws | ||
sock.OnMessage(func(b []byte) { | ||
serverReceived = append(serverReceived, b) | ||
}) | ||
sock.OnClose(func(code int) { | ||
serverCloseCode = code | ||
}) | ||
sock.Start() | ||
}) | ||
|
||
conn := newSocketConnection(t, serverURL) | ||
|
||
// send a message from the server... | ||
sock.Send([]byte("from server")) | ||
|
||
// and read it from the client | ||
msgType, msg, err := conn.ReadMessage() | ||
assert.NoError(t, err) | ||
assert.Equal(t, 1, msgType) | ||
assert.Equal(t, "from server", string(msg)) | ||
|
||
// send a message from the client... | ||
conn.WriteMessage(websocket.TextMessage, []byte("to server")) | ||
|
||
// and check server received it | ||
time.Sleep(500 * time.Millisecond) | ||
assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) | ||
|
||
pongReceived := false | ||
conn.SetPongHandler(func(appData string) error { | ||
pongReceived = true | ||
return nil | ||
}) | ||
|
||
// send a ping message from the client... | ||
conn.WriteMessage(websocket.PingMessage, []byte{}) | ||
|
||
// and give server time to receive it and respond | ||
time.Sleep(500 * time.Millisecond) | ||
|
||
// give the connection something to read because ReadMessage will block until it gets a non-ping-pong message | ||
sock.Send([]byte("dummy")) | ||
conn.ReadMessage() | ||
|
||
assert.True(t, pongReceived) | ||
|
||
var connCloseCode int | ||
conn.SetCloseHandler(func(code int, text string) error { | ||
connCloseCode = code | ||
return nil | ||
}) | ||
|
||
sock.Close(1001) | ||
|
||
conn.ReadMessage() // read the close message | ||
|
||
assert.Equal(t, 1001, serverCloseCode) | ||
assert.Equal(t, 1001, connCloseCode) | ||
} | ||
|
||
func TestSocketClientCloseWithMessage(t *testing.T) { | ||
var sock httpx.WebSocket | ||
var serverCloseCode int | ||
|
||
serverURL := newSocketServer(t, func(ws httpx.WebSocket) { | ||
sock = ws | ||
sock.OnClose(func(code int) { serverCloseCode = code }) | ||
sock.Start() | ||
}) | ||
|
||
conn := newSocketConnection(t, serverURL) | ||
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")) | ||
conn.Close() | ||
|
||
time.Sleep(250 * time.Millisecond) | ||
|
||
assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode) | ||
} | ||
|
||
func TestSocketClientCloseWithoutMessage(t *testing.T) { | ||
var sock httpx.WebSocket | ||
var serverCloseCode int | ||
|
||
serverURL := newSocketServer(t, func(ws httpx.WebSocket) { | ||
sock = ws | ||
sock.OnClose(func(code int) { serverCloseCode = code }) | ||
sock.Start() | ||
}) | ||
|
||
conn := newSocketConnection(t, serverURL) | ||
conn.Close() | ||
|
||
time.Sleep(250 * time.Millisecond) | ||
|
||
assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode) | ||
} |