-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ac4221c
commit ee881ed
Showing
4 changed files
with
224 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
package httpx | ||
|
||
import ( | ||
"net/http" | ||
"sync" | ||
"time" | ||
|
||
"github.com/gorilla/websocket" | ||
) | ||
|
||
const ( | ||
// max time for between reading a message before socket is considered closed | ||
maxReadWait = 60 * time.Second | ||
|
||
// maximum time to wait for message to be written | ||
maxWriteWait = 15 * time.Second | ||
|
||
// how often to send a ping message | ||
pingPeriod = 30 * time.Second | ||
) | ||
|
||
var upgrader = websocket.Upgrader{ | ||
ReadBufferSize: 1024, | ||
WriteBufferSize: 1024, | ||
} | ||
|
||
type WebSocket interface { | ||
Start() | ||
Send(msg []byte) | ||
Close() | ||
|
||
OnMessage(fn func([]byte)) | ||
OnClose(fn func(int)) | ||
} | ||
|
||
// Socket implemention using gorilla library | ||
type socket struct { | ||
conn *websocket.Conn | ||
onMessage func([]byte) | ||
onClose func(int) | ||
outbox chan []byte | ||
readError chan error | ||
writeError chan error | ||
stopWriter chan bool | ||
stopMonitor chan bool | ||
rwWaitGroup sync.WaitGroup | ||
monitorWaitGroup sync.WaitGroup | ||
} | ||
|
||
func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { | ||
conn, err := upgrader.Upgrade(w, r, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
conn.SetReadLimit(maxReadBytes) | ||
|
||
return &socket{ | ||
conn: conn, | ||
onMessage: func([]byte) {}, | ||
onClose: func(int) {}, | ||
outbox: make(chan []byte, sendBuffer), | ||
readError: make(chan error, 1), | ||
writeError: make(chan error, 1), | ||
stopWriter: make(chan bool, 1), | ||
stopMonitor: make(chan bool, 1), | ||
}, nil | ||
} | ||
|
||
func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn } | ||
func (s *socket) OnClose(fn func(int)) { s.onClose = fn } | ||
|
||
func (s *socket) Start() { | ||
s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) | ||
s.conn.SetPongHandler(s.pong) | ||
|
||
go s.monitor() | ||
go s.reader() | ||
go s.writer() | ||
} | ||
|
||
func (s *socket) Send(msg []byte) { | ||
s.outbox <- msg | ||
} | ||
|
||
func (s *socket) Close() { | ||
s.conn.Close() // causes reader to stop | ||
s.stopWriter <- true | ||
s.stopMonitor <- true | ||
|
||
s.monitorWaitGroup.Wait() | ||
} | ||
|
||
func (s *socket) pong(m string) error { | ||
s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) | ||
|
||
return nil | ||
} | ||
|
||
func (s *socket) monitor() { | ||
s.monitorWaitGroup.Add(1) | ||
defer s.monitorWaitGroup.Done() | ||
|
||
closeCode := websocket.CloseNormalClosure | ||
|
||
out: | ||
for { | ||
select { | ||
case err := <-s.readError: | ||
if e, ok := err.(*websocket.CloseError); ok { | ||
closeCode = e.Code | ||
} | ||
s.stopWriter <- true // ensure writer is stopped | ||
break out | ||
case err := <-s.writeError: | ||
if e, ok := err.(*websocket.CloseError); ok { | ||
closeCode = e.Code | ||
} | ||
s.conn.Close() // ensure reader is stopped | ||
break out | ||
case <-s.stopMonitor: | ||
break out | ||
} | ||
} | ||
|
||
s.rwWaitGroup.Wait() | ||
|
||
s.onClose(closeCode) | ||
} | ||
|
||
func (s *socket) reader() { | ||
s.rwWaitGroup.Add(1) | ||
defer s.rwWaitGroup.Done() | ||
|
||
for { | ||
_, message, err := s.conn.ReadMessage() | ||
if err != nil { | ||
s.readError <- err | ||
return | ||
} | ||
|
||
s.onMessage(message) | ||
} | ||
} | ||
|
||
func (s *socket) writer() { | ||
s.rwWaitGroup.Add(1) | ||
defer s.rwWaitGroup.Done() | ||
|
||
ticker := time.NewTicker(pingPeriod) | ||
defer ticker.Stop() | ||
|
||
for { | ||
select { | ||
case msg := <-s.outbox: | ||
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) | ||
|
||
err := s.conn.WriteMessage(websocket.TextMessage, msg) | ||
if err != nil { | ||
s.writeError <- err | ||
return | ||
} | ||
case <-ticker.C: | ||
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) | ||
|
||
if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { | ||
s.writeError <- err | ||
return | ||
} | ||
case <-s.stopWriter: | ||
return | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package httpx_test | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gorilla/websocket" | ||
"github.com/nyaruka/gocommon/httpx" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestSocket(t *testing.T) { | ||
var sock httpx.WebSocket | ||
var err error | ||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
sock, err = httpx.NewWebSocket(w, r, 4096, 5) | ||
sock.Start() | ||
|
||
require.NoError(t, err) | ||
})) | ||
|
||
wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:") | ||
|
||
d := websocket.Dialer{ | ||
Subprotocols: []string{"p1", "p2"}, | ||
ReadBufferSize: 1024, | ||
WriteBufferSize: 1024, | ||
HandshakeTimeout: 30 * time.Second, | ||
} | ||
conn, _, err := d.Dial(wsURL, nil) | ||
assert.NoError(t, err) | ||
assert.NotNil(t, conn) | ||
|
||
sock.Send([]byte("test")) | ||
|
||
msgType, msg, err := conn.ReadMessage() | ||
assert.NoError(t, err) | ||
assert.Equal(t, 1, msgType) | ||
assert.Equal(t, "test", string(msg)) | ||
|
||
sock.Close() | ||
} |