Skip to content

Commit

Permalink
Test custom websocket headers and origin
Browse files Browse the repository at this point in the history
Signed-off-by: Lorenzo Donini <lorenzo.donini@motius.de>
  • Loading branch information
lorenzodonini committed Dec 10, 2020
1 parent fdf0156 commit b2c5d05
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions ws/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,84 @@ func TestInvalidBasicAuth(t *testing.T) {
wsServer.Stop()
}

func TestInvalidOriginHeader(t *testing.T) {
var wsServer *Server
wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from client!")
return nil, nil
})
wsServer.SetNewClientHandler(func(ws Channel) {
assert.Fail(t, "no new connection should be received from client!")
})
go wsServer.Start(serverPort, serverPath)
time.Sleep(500 * time.Millisecond)

// Test message
wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from server!")
return nil, nil
})
// Set invalid origin header
wsClient.SetHeaderValue("Origin", "example.org")
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
// Attempt to connect and expect cross-origin error
err := wsClient.Start(u.String())
require.Error(t, err)
httpErr, ok := err.(HttpConnectionError)
require.True(t, ok)
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
assert.Equal(t, "websocket: bad handshake", httpErr.Message)
// Cleanup
wsServer.Stop()
}

func TestCustomOriginHeaderHandler(t *testing.T) {
var wsServer *Server
origin := "example.org"
connected := make(chan bool)
wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from client!")
return nil, nil
})
wsServer.SetNewClientHandler(func(ws Channel) {
connected <- true
})
wsServer.SetCheckOriginHandler(func(r *http.Request) bool {
return r.Header.Get("Origin") == origin
})
go wsServer.Start(serverPort, serverPath)
time.Sleep(500 * time.Millisecond)

// Test message
wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from server!")
return nil, nil
})
// Set invalid origin header (not example.org)
wsClient.SetHeaderValue("Origin", "localhost")
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
// Attempt to connect and expect cross-origin error
err := wsClient.Start(u.String())
require.Error(t, err)
httpErr, ok := err.(HttpConnectionError)
require.True(t, ok)
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
assert.Equal(t, "websocket: bad handshake", httpErr.Message)

// Re-attempt with correct header
wsClient.SetHeaderValue("Origin", "example.org")
err = wsClient.Start(u.String())
require.NoError(t, err)
result := <-connected
assert.True(t, result)
// Cleanup
wsServer.Stop()
}

func TestValidClientTLSCertificate(t *testing.T) {
var wsServer *Server
// Create self-signed TLS certificate
Expand Down

0 comments on commit b2c5d05

Please sign in to comment.