From 19cd3e6b993dcc8041df495034e8b9813cb443e8 Mon Sep 17 00:00:00 2001 From: Jan Jansen Date: Fri, 10 Jan 2025 09:24:13 +0100 Subject: [PATCH] backend: enhance WSConnLock with thread-safety The WSConnLock struct now provides a robust wrapper around WebSocket connections, ensuring thread-safe write operations while maintaining efficient read access. This helps prevent race conditions in concurrent WebSocket operations throughout the multiplexer. Co-Authored-by: farodin91 Signed-off-by: Kautilya Tripathi --- backend/cmd/multiplexer.go | 88 ++++- backend/cmd/multiplexer_test.go | 675 +++++++++++++++++++------------- 2 files changed, 473 insertions(+), 290 deletions(-) diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go index e83236821e..06cefc31fa 100644 --- a/backend/cmd/multiplexer.go +++ b/backend/cmd/multiplexer.go @@ -63,7 +63,7 @@ type Connection struct { // Status is the status of the connection. Status ConnectionStatus // Client is the WebSocket connection to the client. - Client *websocket.Conn + Client *WSConnLock // Done is a channel to signal when the connection is done. Done chan struct{} // mu is a mutex to synchronize access to the connection. @@ -104,6 +104,68 @@ type Multiplexer struct { kubeConfigStore kubeconfig.ContextStore } +// WSConnLock provides a thread-safe wrapper around a WebSocket connection. +// It ensures that write operations are synchronized using a mutex to prevent +// concurrent writes which could corrupt the WebSocket stream. +type WSConnLock struct { + // conn is the underlying WebSocket connection + conn *websocket.Conn + // writeMu is a mutex to synchronize access to write operations. + // This prevents concurrent writes to the WebSocket connection. + writeMu sync.Mutex +} + +// NewWSConnLock creates a new WSConnLock instance that wraps the provided +// WebSocket connection with thread-safe write operations. +func NewWSConnLock(conn *websocket.Conn) *WSConnLock { + return &WSConnLock{ + conn: conn, + writeMu: sync.Mutex{}, + } +} + +// WriteJSON writes the JSON encoding of v as a message to the WebSocket connection. +// It ensures thread-safety by using a mutex lock during the write operation. +func (conn *WSConnLock) WriteJSON(v interface{}) error { + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return conn.conn.WriteJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the WebSocket connection +// and stores it in the value pointed to by v. +// Note: Reading is already thread-safe in gorilla/websocket, so no mutex is needed. +func (conn *WSConnLock) ReadJSON(v interface{}) error { + return conn.conn.ReadJSON(v) +} + +// ReadMessage reads the next message from the WebSocket connection. +// It returns the message type and payload. +// Note: Reading is already thread-safe in gorilla/websocket, so no mutex is needed. +func (conn *WSConnLock) ReadMessage() (messageType int, p []byte, err error) { + return conn.conn.ReadMessage() +} + +// WriteMessage writes a message to the WebSocket connection with the given type and payload. +// It ensures thread-safety by using a mutex lock during the write operation. +func (conn *WSConnLock) WriteMessage(messageType int, data []byte) error { + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return conn.conn.WriteMessage(messageType, data) +} + +// Close safely closes the WebSocket connection. +// It ensures thread-safety by acquiring the write mutex before closing, +// preventing any concurrent writes during the close operation. +func (conn *WSConnLock) Close() error { + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return conn.conn.Close() +} + // NewMultiplexer creates a new Multiplexer instance. func NewMultiplexer(kubeConfigStore kubeconfig.ContextStore) *Multiplexer { return &Multiplexer{ @@ -183,7 +245,7 @@ func (m *Multiplexer) establishClusterConnection( userID, path, query string, - clientConn *websocket.Conn, + clientConn *WSConnLock, ) (*Connection, error) { config, err := m.getClusterConfigWithFallback(clusterID, userID) if err != nil { @@ -246,7 +308,7 @@ func (m *Multiplexer) createConnection( userID, path, query string, - clientConn *websocket.Conn, + clientConn *WSConnLock, ) *Connection { return &Connection{ ClusterID: clusterID, @@ -355,6 +417,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque defer clientConn.Close() + lockClientConn := NewWSConnLock(clientConn) + for { msg, err := m.readClientMessage(clientConn) if err != nil { @@ -368,9 +432,9 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque continue } - conn, err := m.getOrCreateConnection(msg, clientConn) + conn, err := m.getOrCreateConnection(msg, lockClientConn) if err != nil { - m.handleConnectionError(clientConn, msg, err) + m.handleConnectionError(lockClientConn, msg, err) continue } @@ -408,7 +472,7 @@ func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, er } // getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist. -func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) { +func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *WSConnLock) (*Connection, error) { connKey := m.createConnectionKey(msg.ClusterID, msg.Path, msg.UserID) m.mutex.RLock() @@ -437,7 +501,7 @@ func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.C } // handleConnectionError handles errors that occur when establishing a connection. -func (m *Multiplexer) handleConnectionError(clientConn *websocket.Conn, msg Message, err error) { +func (m *Multiplexer) handleConnectionError(clientConn *WSConnLock, msg Message, err error) { errorMsg := struct { ClusterID string `json:"clusterId"` Error string `json:"error"` @@ -477,7 +541,7 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error } // handleClusterMessages handles messages from a cluster connection. -func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { +func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *WSConnLock) { defer m.cleanupConnection(conn) var lastResourceVersion string @@ -497,7 +561,7 @@ func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websoc // processClusterMessage processes a single message from the cluster. func (m *Multiplexer) processClusterMessage( conn *Connection, - clientConn *websocket.Conn, + clientConn *WSConnLock, lastResourceVersion *string, ) error { messageType, message, err := conn.WSConn.ReadMessage() @@ -541,7 +605,7 @@ func (m *Multiplexer) processClusterMessage( func (m *Multiplexer) sendIfNewResourceVersion( message []byte, conn *Connection, - clientConn *websocket.Conn, + clientConn *WSConnLock, lastResourceVersion *string, ) error { var obj map[string]interface{} @@ -581,7 +645,7 @@ func (m *Multiplexer) sendIfNewResourceVersion( } // sendCompleteMessage sends a COMPLETE message to the client. -func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error { +func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *WSConnLock) error { conn.mu.RLock() if conn.closed { conn.mu.RUnlock() @@ -614,7 +678,7 @@ func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocke // sendDataMessage sends the actual data message to the client. func (m *Multiplexer) sendDataMessage( conn *Connection, - clientConn *websocket.Conn, + clientConn *WSConnLock, messageType int, message []byte, ) error { diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go index 1044d9111e..078ab89545 100644 --- a/backend/cmd/multiplexer_test.go +++ b/backend/cmd/multiplexer_test.go @@ -2,6 +2,7 @@ package main import ( "crypto/tls" + "encoding/base64" "encoding/json" "fmt" "net" @@ -60,6 +61,8 @@ func TestHandleClientWebSocket(t *testing.T) { defer ws.Close() + wsConn := NewWSConnLock(ws) + // Test WATCH message watchMsg := Message{ Type: "WATCH", @@ -67,7 +70,7 @@ func TestHandleClientWebSocket(t *testing.T) { Path: "/api/v1/pods", UserID: "test-user", } - err = ws.WriteJSON(watchMsg) + err = wsConn.WriteJSON(watchMsg) require.NoError(t, err) // Test CLOSE message @@ -77,7 +80,7 @@ func TestHandleClientWebSocket(t *testing.T) { Path: "/api/v1/pods", UserID: "test-user", } - err = ws.WriteJSON(closeMsg) + err = wsConn.WriteJSON(closeMsg) require.NoError(t, err) } @@ -177,11 +180,16 @@ func TestDialWebSocket_Errors(t *testing.T) { func TestMonitorConnection(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) - clientConn, _ := createTestWebSocketConnection() + clientConn, clientServer := createTestWebSocketConnection() - // Updated createConnection call with all required arguments - conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) - conn.WSConn, _ = createTestWebSocketConnection() + defer clientServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) + + wsConn, wsServer := createTestWebSocketConnection() + defer wsServer.Close() + + conn.WSConn = wsConn.conn done := make(chan struct{}) go func() { @@ -197,11 +205,10 @@ func TestMonitorConnection(t *testing.T) { } func TestUpdateStatus(t *testing.T) { - conn := &Connection{ - Status: ConnectionStatus{}, - Done: make(chan struct{}), - mu: sync.RWMutex{}, - } + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) // Test different state transitions states := []ConnectionState{ @@ -212,9 +219,7 @@ func TestUpdateStatus(t *testing.T) { } for _, state := range states { - conn.mu.Lock() - conn.Status.State = state - conn.mu.Unlock() + conn.updateStatus(state, nil) assert.Equal(t, state, conn.Status.State) } @@ -225,25 +230,82 @@ func TestUpdateStatus(t *testing.T) { go func(i int) { defer wg.Done() - conn.mu.Lock() + state := states[i%len(states)] - conn.Status.State = state - conn.mu.Unlock() + conn.updateStatus(state, nil) }(i) } wg.Wait() // Verify final state is valid - conn.mu.RLock() assert.Contains(t, states, conn.Status.State) - conn.mu.RUnlock() } -func TestMonitorConnection_Reconnect(t *testing.T) { - contextStore := kubeconfig.NewContextStore() - m := NewMultiplexer(contextStore) +func TestCleanupConnections(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() - // Create a server that will accept the connection and then close it + defer clientServer.Close() + + wsConn, wsServer := createTestWebSocketConnection() + defer wsServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) + conn.WSConn = wsConn.conn + + connKey := m.createConnectionKey("test-cluster", "/api/v1/pods", "test-user") + m.connections[connKey] = conn + + m.cleanupConnections() + + assert.Empty(t, m.connections) + assert.Equal(t, StateClosed, conn.Status.State) +} + +func TestCloseConnection(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + wsConn, wsServer := createTestWebSocketConnection() + defer wsServer.Close() + + conn := createTestConnection("test-cluster-1", "test-user", "/api/v1/pods", "", clientConn) + conn.WSConn = wsConn.conn + + connKey := m.createConnectionKey("test-cluster-1", "/api/v1/pods", "test-user") + m.connections[connKey] = conn + + m.CloseConnection("test-cluster-1", "/api/v1/pods", "test-user") + assert.Empty(t, m.connections) + assert.True(t, conn.closed) +} + +func createTestConnection( + clusterID, + userID, + path, + query string, + client *WSConnLock, +) *Connection { + return &Connection{ + ClusterID: clusterID, + UserID: userID, + Path: path, + Query: query, + Client: client, + Done: make(chan struct{}), + Status: ConnectionStatus{ + State: StateConnecting, + LastMsg: time.Now(), + }, + mu: sync.RWMutex{}, + writeMu: sync.Mutex{}, + } +} + +func createTestWebSocketConn() (*websocket.Conn, *httptest.Server) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { @@ -251,169 +313,146 @@ func TestMonitorConnection_Reconnect(t *testing.T) { }, } ws, err := upgrader.Upgrade(w, r, nil) - require.NoError(t, err) + if err != nil { + return + } defer ws.Close() - // Keep connection alive briefly - time.Sleep(100 * time.Millisecond) - ws.Close() + for { + messageType, message, err := ws.ReadMessage() + if err != nil { + break + } + err = ws.WriteMessage(messageType, message) + if err != nil { + break + } + } })) - defer server.Close() - - conn := &Connection{ - Status: ConnectionStatus{ - State: StateConnecting, - }, - Done: make(chan struct{}), - } - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec - - ws, err := m.dialWebSocket(wsURL, tlsConfig, "") - require.NoError(t, err) - - conn.WSConn = ws + dialer := newTestDialer() - // Start monitoring in a goroutine - go m.monitorConnection(conn) + conn, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + server.Close() - // Wait for state transitions - time.Sleep(300 * time.Millisecond) + return nil, nil + } - // Verify connection status, it should reconnect - assert.Equal(t, StateConnecting, conn.Status.State) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } - // Clean up - close(conn.Done) + return conn, server } -//nolint:funlen -func TestHandleClusterMessages(t *testing.T) { - m := NewMultiplexer(kubeconfig.NewContextStore()) - clientConn, clientServer := createTestWebSocketConnection() - - defer clientServer.Close() +func createTestWebSocketConnection() (*WSConnLock, *httptest.Server) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } - clusterConn, clusterServer := createTestWebSocketConnection() + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } - defer clusterServer.Close() + // Echo back any messages received + go func() { + for { + messageType, message, err := ws.ReadMessage() + if err != nil { + break + } + + if err := ws.WriteMessage(messageType, message); err != nil { + break + } + } + }() + })) - // Add RequestID to the createConnection call - conn := m.createConnection("minikube", "test-user", "/api/v1/pods", "watch=true", clientConn) - conn.WSConn = clusterConn + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := newTestDialer() - done := make(chan struct{}) - go func() { - m.handleClusterMessages(conn, clientConn) - close(done) - }() + conn, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + server.Close() + return nil, nil + } - // Send a test message from the cluster - testMessage := []byte(`{"kind":"Pod","apiVersion":"v1","metadata":{"name":"test-pod"}}`) - err := clusterConn.WriteMessage(websocket.TextMessage, testMessage) - require.NoError(t, err) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } - // Read the message from the client connection - _, receivedMessage, err := clientConn.ReadMessage() - require.NoError(t, err) + return NewWSConnLock(conn), server +} - var wrapperMsg struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - RequestID string `json:"requestId"` - Data string `json:"data"` - Binary bool `json:"binary"` - } +func TestWSConnLock(t *testing.T) { + wsConn, server := createTestWebSocketConnection() + defer server.Close() - err = json.Unmarshal(receivedMessage, &wrapperMsg) - require.NoError(t, err) + // Test concurrent writes + var wg sync.WaitGroup - assert.Equal(t, "minikube", wrapperMsg.ClusterID) - assert.Equal(t, "/api/v1/pods", wrapperMsg.Path) - assert.Equal(t, "watch=true", wrapperMsg.Query) - assert.Equal(t, "test-user", wrapperMsg.UserID) - assert.False(t, wrapperMsg.Binary) + for i := 0; i < 10; i++ { + wg.Add(1) - // Parse the Data field separately - var podData map[string]interface{} - err = json.Unmarshal([]byte(wrapperMsg.Data), &podData) - require.NoError(t, err) - assert.Equal(t, "Pod", podData["kind"]) - assert.Equal(t, "v1", podData["apiVersion"]) + go func(i int) { + defer wg.Done() - // Close the connection to trigger the end of handleClusterMessages - conn.WSConn.Close() + msg := fmt.Sprintf("message-%d", i) - // Wait for handleClusterMessages to finish - select { - case <-done: - // Function completed successfully - case <-time.After(5 * time.Second): - t.Fatal("Test timed out") + err := wsConn.WriteJSON(msg) + assert.NoError(t, err) + }(i) } - assert.Equal(t, StateConnecting, conn.Status.State) -} + wg.Wait() -func TestCleanupConnections(t *testing.T) { - m := NewMultiplexer(kubeconfig.NewContextStore()) - clientConn, _ := createTestWebSocketConnection() - // Add RequestID to the createConnection call - conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) - conn.WSConn, _ = createTestWebSocketConnection() + // Test ReadJSON + var msg string + err := wsConn.ReadJSON(&msg) + assert.NoError(t, err) + assert.Contains(t, msg, "message-") +} - // Use the new connection key format - connKey := "test-cluster:/api/v1/pods:test-request-id" - m.connections[connKey] = conn +func createMockKubeAPIServer() *httptest.Server { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } - m.cleanupConnections() + defer c.Close() - assert.Empty(t, m.connections) - assert.Equal(t, StateClosed, conn.Status.State) -} + // Echo messages back + for { + _, msg, err := c.ReadMessage() + if err != nil { + break + } + if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { + break + } + } + })) -func TestCreateWebSocketURL(t *testing.T) { - tests := []struct { - name string - host string - path string - query string - expected string - }{ - { - name: "basic URL without query", - host: "http://localhost:8080", - path: "/api/v1/pods", - query: "", - expected: "wss://localhost:8080/api/v1/pods", - }, - { - name: "URL with query parameters", - host: "https://example.com", - path: "/api/v1/pods", - query: "watch=true", - expected: "wss://example.com/api/v1/pods?watch=true", - }, - { - name: "URL with path and multiple query parameters", - host: "https://k8s.example.com", - path: "/api/v1/namespaces/default/pods", - query: "watch=true&labelSelector=app%3Dnginx", - expected: "wss://k8s.example.com/api/v1/namespaces/default/pods?watch=true&labelSelector=app%3Dnginx", - }, + // Configure the test client to accept the test server's TLS certificate + server.Client().Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := createWebSocketURL(tt.host, tt.path, tt.query) - assert.Equal(t, tt.expected, result) - }) - } + return server } func TestGetOrCreateConnection(t *testing.T) { @@ -527,7 +566,12 @@ func TestReconnect(t *testing.T) { defer clientServer.Close() // Create initial connection - conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + conn := m.createConnection("test-cluster", "test-user", "/api/v1/services", "watch=true", clientConn) + wsConn, wsServer := createTestWebSocketConnection() + + defer wsServer.Close() + + conn.WSConn = wsConn.conn conn.Status.State = StateError // Simulate an error state // Test successful reconnection @@ -549,21 +593,14 @@ func TestReconnect(t *testing.T) { // Test reconnection with closed connection conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) - clusterConn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) - require.NoError(t, err) - require.NotNil(t, clusterConn) + wsConn2, wsServer2 := createTestWebSocketConnection() - // Close the connection and wait for cleanup - conn.closed = true - if conn.WSConn != nil { - conn.WSConn.Close() - } + defer wsServer2.Close() - if conn.Client != nil { - conn.Client.Close() - } + conn.WSConn = wsConn2.conn - close(conn.Done) + // Close the connection and wait for cleanup + conn.closed = true // Mark connection as closed // Try to reconnect the closed connection newConn, err = m.reconnect(conn) @@ -571,21 +608,6 @@ func TestReconnect(t *testing.T) { assert.Nil(t, newConn) } -func TestCloseConnection(t *testing.T) { - m := NewMultiplexer(kubeconfig.NewContextStore()) - clientConn, _ := createTestWebSocketConnection() - conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) - conn.WSConn, _ = createTestWebSocketConnection() - - connKey := "test-cluster:/api/v1/pods:test-user" - m.connections[connKey] = conn - - m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") - assert.Empty(t, m.connections) - // It will reconnect to the cluster - assert.Equal(t, StateConnecting, conn.Status.State) -} - func TestCreateWrapperMessage(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) conn := &Connection{ @@ -668,47 +690,6 @@ func TestHandleConnectionError(t *testing.T) { } } -func TestWriteMessageToCluster(t *testing.T) { - m := NewMultiplexer(kubeconfig.NewContextStore()) - clusterConn, clusterServer := createTestWebSocketConnection() - - defer clusterServer.Close() - - conn := &Connection{ - ClusterID: "test-cluster", - WSConn: clusterConn, - } - - testMessage := []byte("Hello, Cluster!") - - // Capture the message sent to the cluster - var receivedMessage []byte - - done := make(chan bool) - go func() { - _, receivedMessage, _ = clusterConn.ReadMessage() - done <- true - }() - - err := m.writeMessageToCluster(conn, testMessage) - assert.NoError(t, err) - - select { - case <-done: - assert.Equal(t, testMessage, receivedMessage) - case <-time.After(time.Second): - t.Fatal("Test timed out") - } - - // Test error case - clusterConn.Close() - - err = m.writeMessageToCluster(conn, testMessage) - - assert.Error(t, err) - assert.Equal(t, StateError, conn.Status.State) -} - //nolint:funlen func TestReadClientMessage_InvalidMessage(t *testing.T) { contextStore := kubeconfig.NewContextStore() @@ -838,9 +819,15 @@ func TestMonitorConnection_ReconnectFailure(t *testing.T) { }) require.NoError(t, err) - clientConn, _ := createTestWebSocketConnection() + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) - conn.WSConn, _ = createTestWebSocketConnection() + wsConn, wsServer := createTestWebSocketConn() + + defer wsServer.Close() + + conn.WSConn = wsConn // Start monitoring done := make(chan struct{}) @@ -864,12 +851,14 @@ func TestMonitorConnection_ReconnectFailure(t *testing.T) { <-done } +//nolint:funlen func TestHandleClientWebSocket_InvalidMessages(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.HandleClientWebSocket(w, r) })) + defer server.Close() // Test invalid JSON @@ -913,6 +902,7 @@ func TestHandleClientWebSocket_InvalidMessages(t *testing.T) { Path: "/api/v1/pods", UserID: "test-user", }) + require.NoError(t, err) // Should receive an error message or close @@ -1004,7 +994,7 @@ func TestSendCompleteMessage_ClosedConnection(t *testing.T) { assert.Equal(t, conn.Query, msg.Query) assert.Equal(t, conn.UserID, msg.UserID) - // Test with closed connection + // Test sending to closed connection clientConn.Close() err = m.sendCompleteMessage(conn, clientConn) assert.NoError(t, err) @@ -1013,19 +1003,19 @@ func TestSendCompleteMessage_ClosedConnection(t *testing.T) { func TestSendCompleteMessage_ErrorConditions(t *testing.T) { tests := []struct { name string - setupConn func(*Connection, *websocket.Conn) + setupConn func(*Connection, *WSConnLock) expectedError bool }{ { name: "connection already marked as closed", - setupConn: func(conn *Connection, _ *websocket.Conn) { + setupConn: func(conn *Connection, _ *WSConnLock) { conn.closed = true }, expectedError: false, }, { name: "normal closure", - setupConn: func(_ *Connection, clientConn *websocket.Conn) { + setupConn: func(_ *Connection, clientConn *WSConnLock) { //nolint:errcheck clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -1035,7 +1025,7 @@ func TestSendCompleteMessage_ErrorConditions(t *testing.T) { }, { name: "unexpected close error", - setupConn: func(_ *Connection, clientConn *websocket.Conn) { + setupConn: func(_ *Connection, clientConn *WSConnLock) { //nolint:errcheck clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseProtocolError, "")) @@ -1049,6 +1039,7 @@ func TestSendCompleteMessage_ErrorConditions(t *testing.T) { t.Run(tt.name, func(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() conn := &Connection{ @@ -1070,74 +1061,202 @@ func TestSendCompleteMessage_ErrorConditions(t *testing.T) { } } -func createMockKubeAPIServer() *httptest.Server { - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func TestMonitorConnection_Reconnect(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will accept the connection and then close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) - defer c.Close() + defer ws.Close() - // Echo messages back - for { - _, msg, err := c.ReadMessage() - if err != nil { - break - } - if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { - break - } - } + // Keep connection alive briefly + time.Sleep(100 * time.Millisecond) })) - // Configure the test client to accept the test server's TLS certificate - server.Client().Transport.(*http.Transport).TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, //nolint:gosec - } + defer server.Close() - return server + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/services", "", clientConn) + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket(wsURL, tlsConfig, "") + require.NoError(t, err) + + conn.WSConn = ws + + // Start monitoring in a goroutine + go m.monitorConnection(conn) + + // Wait for state transitions + time.Sleep(300 * time.Millisecond) + + // Verify connection status, it should be in error state or connecting + assert.Contains(t, []ConnectionState{StateError, StateConnecting}, conn.Status.State) + + // Clean up + close(conn.Done) } -func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{} - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } +func TestWriteMessageToCluster(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clusterConn, clusterServer := createTestWebSocketConnection() - defer c.Close() + defer clusterServer.Close() - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } + conn := &Connection{ + ClusterID: "test-cluster", + WSConn: clusterConn.conn, + } - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) + testMessage := []byte("Hello, Cluster!") - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - dialer := newTestDialer() + // Capture the message sent to the cluster + var receivedMessage []byte - ws, resp, err := dialer.Dial(wsURL, nil) - if err != nil { - panic(err) + done := make(chan bool) + + go func() { + _, receivedMessage, _ = clusterConn.conn.ReadMessage() + done <- true + }() + + err := m.writeMessageToCluster(conn, testMessage) + assert.NoError(t, err) + + select { + case <-done: + assert.Equal(t, testMessage, receivedMessage) + case <-time.After(time.Second): + t.Fatal("Test timed out") } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + // Test error case + clusterConn.Close() + + err = m.writeMessageToCluster(conn, testMessage) + + assert.Error(t, err) + assert.Equal(t, StateError, conn.Status.State) +} + +func TestHandleClusterMessages(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + wsConn, wsServer := createTestWebSocketConnection() + defer wsServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + conn.WSConn = wsConn.conn + + done := make(chan struct{}) + go func() { + m.handleClusterMessages(conn, clientConn) + close(done) + }() + + // Send a test message from the cluster + testMessage := []byte(`{"metadata":{"resourceVersion":"1"},"kind":"Pod","apiVersion":"v1","metadata":{"name":"test-pod"}}`) //nolint:lll + err := wsConn.WriteMessage(websocket.TextMessage, testMessage) + require.NoError(t, err) + + // Read the message from the client connection + var msg Message + err = clientConn.ReadJSON(&msg) + require.NoError(t, err) + + assert.Equal(t, "test-cluster", msg.ClusterID) + assert.Equal(t, "/api/v1/pods", msg.Path) + assert.Equal(t, "watch=true", msg.Query) + assert.Equal(t, "test-user", msg.UserID) + + // Close the connection + wsConn.Close() + + // Wait for handleClusterMessages to finish + select { + case <-done: + // Function completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Test timed out") } +} + +func TestSendCompleteMessage(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := createTestConnection("test-cluster-1", "test-user-1", "/api/v1/pods", "", clientConn) + + // Test sending complete message + err := m.sendCompleteMessage(conn, clientConn) + assert.NoError(t, err) + + // Verify the complete message was sent + var msg Message + err = clientConn.ReadJSON(&msg) + require.NoError(t, err) + assert.Equal(t, "COMPLETE", msg.Type) + assert.Equal(t, conn.ClusterID, msg.ClusterID) + assert.Equal(t, conn.Path, msg.Path) + assert.Equal(t, conn.Query, msg.Query) + assert.Equal(t, conn.UserID, msg.UserID) + + // Test sending to closed connection + conn.closed = true + err = m.sendCompleteMessage(conn, clientConn) + assert.NoError(t, err) // Should return nil for closed connection +} + +func TestSendDataMessage(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := createTestConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) + + // Test sending a text message + textMsg := []byte("Hello, World!") + err := m.sendDataMessage(conn, clientConn, websocket.TextMessage, textMsg) + assert.NoError(t, err) + + // Verify text message + var msg Message + err = clientConn.ReadJSON(&msg) + require.NoError(t, err) + assert.Equal(t, string(textMsg), msg.Data) + assert.False(t, msg.Binary) + + // Test sending a binary message + binaryMsg := []byte{0x01, 0x02, 0x03} + err = m.sendDataMessage(conn, clientConn, websocket.BinaryMessage, binaryMsg) + assert.NoError(t, err) - return ws, server + // Verify binary message + err = clientConn.ReadJSON(&msg) + require.NoError(t, err) + assert.Equal(t, base64.StdEncoding.EncodeToString(binaryMsg), msg.Data) + assert.True(t, msg.Binary) + + // Test sending to closed connection + conn.closed = true + err = m.sendDataMessage(conn, clientConn, websocket.TextMessage, textMsg) + assert.NoError(t, err) // Should return nil even for closed connection }