diff --git a/backend/cmd/headlamp.go b/backend/cmd/headlamp.go
index 1251f84e56..06074e9881 100644
--- a/backend/cmd/headlamp.go
+++ b/backend/cmd/headlamp.go
@@ -59,6 +59,7 @@ type HeadlampConfig struct {
 	proxyURLs       []string
 	cache           cache.Cache[interface{}]
 	kubeConfigStore kubeconfig.ContextStore
+	multiplexer     *Multiplexer
 }
 
 const DrainNodeCacheTTL = 20 // seconds
@@ -1599,6 +1600,9 @@ func (c *HeadlampConfig) addClusterSetupRoute(r *mux.Router) {
 
 	// Rename a cluster
 	r.HandleFunc("/cluster/{name}", c.renameCluster).Methods("PUT")
+
+	// WebSocket connections
+	r.HandleFunc("/wsMultiplexer", c.multiplexer.HandleClientWebSocket)
 }
 
 /*
diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go
new file mode 100644
index 0000000000..ad37f91359
--- /dev/null
+++ b/backend/cmd/multiplexer.go
@@ -0,0 +1,632 @@
+package main
+
+import (
+	"crypto/tls"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/url"
+	"sync"
+	"time"
+
+	"github.com/gorilla/websocket"
+	"github.com/headlamp-k8s/headlamp/backend/pkg/kubeconfig"
+	"github.com/headlamp-k8s/headlamp/backend/pkg/logger"
+	"k8s.io/client-go/rest"
+)
+
+const (
+	// StateConnecting is the state when the connection is being established.
+	StateConnecting ConnectionState = "connecting"
+	// StateConnected is the state when the connection is established.
+	StateConnected ConnectionState = "connected"
+	// StateError is the state when the connection has an error.
+	StateError ConnectionState = "error"
+	// StateClosed is the state when the connection is closed.
+	StateClosed ConnectionState = "closed"
+)
+
+const (
+	// HeartbeatInterval is the interval at which the multiplexer sends heartbeat messages to the client.
+	HeartbeatInterval = 30 * time.Second
+	// HandshakeTimeout is the timeout for the handshake with the client.
+	HandshakeTimeout = 45 * time.Second
+	// CleanupRoutineInterval is the interval at which the multiplexer cleans up unused connections.
+	CleanupRoutineInterval = 5 * time.Minute
+)
+
+// ConnectionState represents the current state of a connection.
+type ConnectionState string
+
+type ConnectionStatus struct {
+	// State is the current state of the connection.
+	State ConnectionState `json:"state"`
+	// Error is the error message of the connection.
+	Error string `json:"error,omitempty"`
+	// LastMsg is the last message time of the connection.
+	LastMsg time.Time `json:"lastMsg"`
+}
+
+// Connection represents a WebSocket connection to a Kubernetes cluster.
+type Connection struct {
+	// ClusterID is the ID of the cluster.
+	ClusterID string
+	// UserID is the ID of the user.
+	UserID string
+	// Path is the path of the connection.
+	Path string
+	// Query is the query of the connection.
+	Query string
+	// WSConn is the WebSocket connection to the cluster.
+	WSConn *websocket.Conn
+	// Status is the status of the connection.
+	Status ConnectionStatus
+	// Client is the WebSocket connection to the client.
+	Client *websocket.Conn
+	// Done is a channel to signal when the connection is done.
+	Done chan struct{}
+	// mu is a mutex to synchronize access to the connection.
+	mu sync.RWMutex
+}
+
+// Message represents a WebSocket message structure.
+type Message struct {
+	// ClusterID is the ID of the cluster.
+	ClusterID string `json:"clusterId"`
+	// Path is the path of the connection.
+	Path string `json:"path"`
+	// Query is the query of the connection.
+	Query string `json:"query"`
+	// UserID is the ID of the user.
+	UserID string `json:"userId"`
+	// Data contains the message payload.
+	Data []byte `json:"data,omitempty"`
+}
+
+// Multiplexer manages multiple WebSocket connections.
+type Multiplexer struct {
+	// connections is a map of connections indexed by cluster ID, path, and user ID.
+	connections map[string]*Connection
+	// mutex is a mutex to synchronize access to the connections.
+	mutex sync.RWMutex
+	// upgrader is the WebSocket upgrader.
+	upgrader websocket.Upgrader
+	// kubeConfigStore is the kubeconfig store.
+	kubeConfigStore kubeconfig.ContextStore
+}
+
+// NewMultiplexer creates a new Multiplexer instance.
+func NewMultiplexer(kubeConfigStore kubeconfig.ContextStore) *Multiplexer {
+	return &Multiplexer{
+		connections:     make(map[string]*Connection),
+		kubeConfigStore: kubeConfigStore,
+		upgrader: websocket.Upgrader{
+			CheckOrigin: func(r *http.Request) bool {
+				return true
+			},
+		},
+	}
+}
+
+// updateStatus updates the status of a connection and notifies the client.
+func (c *Connection) updateStatus(state ConnectionState, err error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.Status.State = state
+	c.Status.LastMsg = time.Now()
+
+	if err != nil {
+		c.Status.Error = err.Error()
+	} else {
+		c.Status.Error = ""
+	}
+
+	if c.Client != nil {
+		statusData := struct {
+			State string `json:"state"`
+			Error string `json:"error"`
+		}{
+			State: string(state),
+			Error: c.Status.Error,
+		}
+
+		jsonData, jsonErr := json.Marshal(statusData)
+		if jsonErr != nil {
+			logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message")
+
+			return
+		}
+
+		statusMsg := Message{
+			ClusterID: c.ClusterID,
+			Path:      c.Path,
+			Data:      jsonData,
+		}
+
+		err := c.Client.WriteJSON(statusMsg)
+		if err != nil {
+			logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client")
+		}
+	}
+}
+
+// establishClusterConnection creates a new WebSocket connection to a Kubernetes cluster.
+func (m *Multiplexer) establishClusterConnection(
+	clusterID,
+	userID,
+	path,
+	query string,
+	clientConn *websocket.Conn,
+) (*Connection, error) {
+	config, err := m.getClusterConfigWithFallback(clusterID, userID)
+	if err != nil {
+		logger.Log(logger.LevelError, map[string]string{"clusterID": clusterID}, err, "getting cluster config")
+
+		return nil, err
+	}
+
+	connection := m.createConnection(clusterID, userID, path, query, clientConn)
+
+	wsURL := createWebSocketURL(config.Host, path, query)
+
+	tlsConfig, err := rest.TLSConfigFor(config)
+	if err != nil {
+		connection.updateStatus(StateError, err)
+
+		return nil, fmt.Errorf("failed to get TLS config: %v", err)
+	}
+
+	conn, err := m.dialWebSocket(wsURL, tlsConfig, config.Host)
+	if err != nil {
+		connection.updateStatus(StateError, err)
+
+		return nil, err
+	}
+
+	connection.WSConn = conn
+	connection.updateStatus(StateConnected, nil)
+
+	m.mutex.Lock()
+	// Key must match the lookup in getOrCreateConnection and CloseConnection,
+	// otherwise the stored connection is never found again.
+	m.connections[fmt.Sprintf("%s:%s:%s", clusterID, path, userID)] = connection
+	m.mutex.Unlock()
+
+	go m.monitorConnection(connection)
+
+	return connection, nil
+}
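Aside: every site that touches the connections map builds its key as `clusterID:path:userID` (see getOrCreateConnection and CloseConnection below). A small helper would keep those call sites from drifting; a sketch, with a hypothetical name that is not part of this change:

```go
// Hypothetical helper, not in this PR: a single source of truth for the
// connection-map key used by establishClusterConnection, reconnect,
// getOrCreateConnection, and CloseConnection.
func (m *Multiplexer) connectionKey(clusterID, path, userID string) string {
	return fmt.Sprintf("%s:%s:%s", clusterID, path, userID)
}
```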
+// getClusterConfigWithFallback attempts to get the cluster config,
+// falling back to a combined key for stateless clusters.
+func (m *Multiplexer) getClusterConfigWithFallback(clusterID, userID string) (*rest.Config, error) {
+	// Try to get config for stateful cluster first.
+	config, err := m.getClusterConfig(clusterID)
+	if err != nil {
+		// If not found, try with the combined key for stateless clusters.
+		combinedKey := fmt.Sprintf("%s%s", clusterID, userID)
+
+		config, err = m.getClusterConfig(combinedKey)
+		if err != nil {
+			return nil, fmt.Errorf("getting cluster config: %v", err)
+		}
+	}
+
+	return config, nil
+}
+
+// createConnection creates a new Connection instance.
+func (m *Multiplexer) createConnection(
+	clusterID,
+	userID,
+	path,
+	query string,
+	clientConn *websocket.Conn,
+) *Connection {
+	return &Connection{
+		ClusterID: clusterID,
+		UserID:    userID,
+		Path:      path,
+		Query:     query,
+		Client:    clientConn,
+		Done:      make(chan struct{}),
+		Status: ConnectionStatus{
+			State:   StateConnecting,
+			LastMsg: time.Now(),
+		},
+	}
+}
+
+// dialWebSocket establishes a WebSocket connection.
+func (m *Multiplexer) dialWebSocket(wsURL string, tlsConfig *tls.Config, host string) (*websocket.Conn, error) {
+	dialer := websocket.Dialer{
+		TLSClientConfig:  tlsConfig,
+		HandshakeTimeout: HandshakeTimeout,
+	}
+
+	conn, resp, err := dialer.Dial(
+		wsURL,
+		http.Header{
+			"Origin": {host},
+		},
+	)
+	if err != nil {
+		logger.Log(logger.LevelError, nil, err, "dialing WebSocket")
+		// We only attempt to close the response body if there was an error and resp is not nil.
+		// In the successful case (when err is nil), resp will actually be nil for WebSocket connections,
+		// so we don't need to close anything.
+		if resp != nil {
+			defer resp.Body.Close()
+		}
+
+		return nil, fmt.Errorf("dialing WebSocket: %v", err)
+	}
+
+	return conn, nil
+}
+
+// monitorConnection monitors the health of a connection and attempts to reconnect if necessary.
+func (m *Multiplexer) monitorConnection(conn *Connection) {
+	heartbeat := time.NewTicker(HeartbeatInterval)
+	defer heartbeat.Stop()
+
+	for {
+		select {
+		case <-conn.Done:
+			conn.updateStatus(StateClosed, nil)
+
+			return
+		case <-heartbeat.C:
+			if err := conn.WSConn.WriteMessage(websocket.PingMessage, nil); err != nil {
+				conn.updateStatus(StateError, fmt.Errorf("heartbeat failed: %v", err))
+
+				if newConn, err := m.reconnect(conn); err != nil {
+					logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, err, "reconnecting to cluster")
+				} else {
+					conn = newConn
+				}
+			}
+		}
+	}
+}
+
+// reconnect attempts to reestablish a connection.
+func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) {
+	if conn.WSConn != nil {
+		conn.WSConn.Close()
+	}
+
+	newConn, err := m.establishClusterConnection(
+		conn.ClusterID,
+		conn.UserID,
+		conn.Path,
+		conn.Query,
+		conn.Client,
+	)
+	if err != nil {
+		logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, err, "reconnecting to cluster")
+
+		return nil, err
+	}
+
+	m.mutex.Lock()
+	m.connections[fmt.Sprintf("%s:%s:%s", conn.ClusterID, conn.Path, conn.UserID)] = newConn
+	m.mutex.Unlock()
+
+	return newConn, nil
+}
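Before the handler itself, it may help to see the wire protocol from the client's side. The sketch below dials the new endpoint with gorilla/websocket and requests a pod watch. It assumes a backend on localhost:4466 (the default backend port); the cluster and user IDs are illustrative, and since `Message` lives in the backend's `main` package, the client re-declares its JSON shape:

```go
package main

import (
	"log"

	"github.com/gorilla/websocket"
)

// wsMessage mirrors the backend's Message JSON shape; the field tags must
// match those on the server side.
type wsMessage struct {
	ClusterID string `json:"clusterId"`
	Path      string `json:"path"`
	Query     string `json:"query"`
	UserID    string `json:"userId"`
	Data      []byte `json:"data,omitempty"`
}

func main() {
	// Assumes a Headlamp backend on localhost:4466; adjust as needed.
	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:4466/wsMultiplexer", nil)
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// Ask the multiplexer to open (or reuse) a pod watch on "minikube".
	msg := wsMessage{
		ClusterID: "minikube",
		Path:      "/api/v1/pods",
		Query:     "watch=true",
		UserID:    "user-1",
	}
	if err := conn.WriteJSON(msg); err != nil {
		log.Fatalf("write: %v", err)
	}

	// Every frame relayed back is a JSON envelope; "data" carries the
	// cluster's payload (base64-encoded when "binary" is true).
	for {
		var envelope map[string]interface{}
		if err := conn.ReadJSON(&envelope); err != nil {
			log.Fatalf("read: %v", err)
		}
		log.Printf("from %v: %v", envelope["clusterId"], envelope["data"])
	}
}
```

Note that encoding/json marshals the `[]byte` Data field as base64, which the backend's json.Unmarshal transparently decodes back to bytes.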
+// HandleClientWebSocket handles incoming WebSocket connections from clients.
+func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Request) {
+	clientConn, err := m.upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		logger.Log(logger.LevelError, nil, err, "upgrading connection")
+
+		return
+	}
+
+	defer clientConn.Close()
+
+	for {
+		msg, err := m.readClientMessage(clientConn)
+		if err != nil {
+			break
+		}
+
+		// Check if it's a close control message. len(nil) is 0 in Go, so no
+		// separate nil check is needed.
+		if len(msg.Data) > 0 && string(msg.Data) == "close" {
+			err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID)
+			if err != nil {
+				logger.Log(
+					logger.LevelError,
+					map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID},
+					err,
+					"closing connection",
+				)
+			}
+
+			continue
+		}
+
+		conn, err := m.getOrCreateConnection(msg, clientConn)
+		if err != nil {
+			m.handleConnectionError(clientConn, msg, err)
+
+			continue
+		}
+
+		if len(msg.Data) > 0 && conn.Status.State == StateConnected {
+			err = m.writeMessageToCluster(conn, msg.Data)
+			if err != nil {
+				continue
+			}
+		}
+	}
+
+	m.cleanupConnections()
+}
+
+// readClientMessage reads a message from the client WebSocket connection.
+func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, error) {
+	var msg Message
+
+	_, rawMessage, err := clientConn.ReadMessage()
+	if err != nil {
+		logger.Log(logger.LevelError, nil, err, "reading message")
+
+		return Message{}, err
+	}
+
+	err = json.Unmarshal(rawMessage, &msg)
+	if err != nil {
+		logger.Log(logger.LevelError, nil, err, "unmarshaling message")
+
+		return Message{}, err
+	}
+
+	return msg, nil
+}
+
+// getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist.
+func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) {
+	connKey := fmt.Sprintf("%s:%s:%s", msg.ClusterID, msg.Path, msg.UserID)
+
+	m.mutex.RLock()
+	conn, exists := m.connections[connKey]
+	m.mutex.RUnlock()
+
+	if !exists {
+		var err error
+
+		conn, err = m.establishClusterConnection(msg.ClusterID, msg.UserID, msg.Path, msg.Query, clientConn)
+		if err != nil {
+			logger.Log(
+				logger.LevelError,
+				map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID},
+				err,
+				"establishing cluster connection",
+			)
+
+			return nil, err
+		}
+
+		go m.handleClusterMessages(conn, clientConn)
+	}
+
+	return conn, nil
+}
+
+// handleConnectionError handles errors that occur when establishing a connection.
+func (m *Multiplexer) handleConnectionError(clientConn *websocket.Conn, msg Message, err error) {
+	errorMsg := struct {
+		ClusterID string `json:"clusterId"`
+		Error     string `json:"error"`
+	}{
+		ClusterID: msg.ClusterID,
+		Error:     err.Error(),
+	}
+
+	// Use a separate variable so a write failure does not clobber the original
+	// error, which is logged below.
+	if writeErr := clientConn.WriteJSON(errorMsg); writeErr != nil {
+		logger.Log(
+			logger.LevelError,
+			map[string]string{"clusterID": msg.ClusterID},
+			writeErr,
+			"writing error message to client",
+		)
+	}
+
+	logger.Log(logger.LevelError, map[string]string{"clusterID": msg.ClusterID}, err, "establishing cluster connection")
+}
+
+// writeMessageToCluster writes a message to the cluster WebSocket connection.
+func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error {
+	err := conn.WSConn.WriteMessage(websocket.BinaryMessage, data)
+	if err != nil {
+		conn.updateStatus(StateError, err)
+		logger.Log(
+			logger.LevelError,
+			map[string]string{"clusterID": conn.ClusterID},
+			err,
+			"writing message to cluster",
+		)
+
+		return err
+	}
+
+	return nil
+}
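For reference, each frame relayed back to the client is wrapped in a JSON envelope, built by createWrapperMessage below; binary frames arrive base64-encoded with `binary: true`. A client-side mirror of that shape, with illustrative values in the comment:

```go
// Client-side mirror of the wrapper built by createWrapperMessage.
// Example payload (illustrative values):
//
//	{"clusterId":"minikube","path":"/api/v1/pods","query":"watch=true",
//	 "userId":"user-1","data":"{\"kind\":\"Pod\",...}","binary":false}
type clusterEnvelope struct {
	ClusterID string `json:"clusterId"`
	Path      string `json:"path"`
	Query     string `json:"query"`
	UserID    string `json:"userId"`
	Data      string `json:"data"` // base64-encoded when Binary is true
	Binary    bool   `json:"binary"`
}
```

+// handleClusterMessages handles messages from a cluster connection.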
+func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { + defer func() { + conn.updateStatus(StateClosed, nil) + conn.WSConn.Close() + }() + + for { + select { + case <-conn.Done: + return + default: + if err := m.processClusterMessage(conn, clientConn); err != nil { + return + } + } + } +} + +// processClusterMessage processes a message from a cluster connection. +func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error { + messageType, message, err := conn.WSConn.ReadMessage() + if err != nil { + m.handleReadError(conn, err) + + return err + } + + wrapperMsg := m.createWrapperMessage(conn, messageType, message) + + if err := clientConn.WriteJSON(wrapperMsg); err != nil { + m.handleWriteError(conn, err) + + return err + } + + conn.mu.Lock() + conn.Status.LastMsg = time.Now() + conn.mu.Unlock() + + return nil +} + +// createWrapperMessage creates a wrapper message for a cluster connection. +func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct { + ClusterID string `json:"clusterId"` + Path string `json:"path"` + Query string `json:"query"` + UserID string `json:"userId"` + Data string `json:"data"` + Binary bool `json:"binary"` +} { + wrapperMsg := struct { + ClusterID string `json:"clusterId"` + Path string `json:"path"` + Query string `json:"query"` + UserID string `json:"userId"` + Data string `json:"data"` + Binary bool `json:"binary"` + }{ + ClusterID: conn.ClusterID, + Path: conn.Path, + Query: conn.Query, + UserID: conn.UserID, + Binary: messageType == websocket.BinaryMessage, + } + + if messageType == websocket.BinaryMessage { + wrapperMsg.Data = base64.StdEncoding.EncodeToString(message) + } else { + wrapperMsg.Data = string(message) + } + + return wrapperMsg +} + +// handleReadError handles errors that occur when reading a message from a cluster connection. +func (m *Multiplexer) handleReadError(conn *Connection, err error) { + conn.updateStatus(StateError, err) + logger.Log( + logger.LevelError, + map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, + err, + "reading message from cluster", + ) +} + +// handleWriteError handles errors that occur when writing a message to a client connection. +func (m *Multiplexer) handleWriteError(conn *Connection, err error) { + conn.updateStatus(StateError, err) + logger.Log( + logger.LevelError, + map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, + err, + "writing message to client", + ) +} + +// cleanupConnections closes and removes all connections. +func (m *Multiplexer) cleanupConnections() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for key, conn := range m.connections { + conn.updateStatus(StateClosed, nil) + close(conn.Done) + + if conn.WSConn != nil { + conn.WSConn.Close() + } + + delete(m.connections, key) + } +} + +// getClusterConfig retrieves the REST config for a given cluster. +func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) { + ctxtProxy, err := m.kubeConfigStore.GetContext(clusterID) + if err != nil { + return nil, fmt.Errorf("getting context: %v", err) + } + + clientConfig, err := ctxtProxy.RESTConfig() + if err != nil { + return nil, fmt.Errorf("getting REST config: %v", err) + } + + return clientConfig, nil +} + +// CloseConnection closes a specific connection based on its identifier. 
+func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error {
+	connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID)
+
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	conn, exists := m.connections[connKey]
+	if !exists {
+		return fmt.Errorf("connection not found for key: %s", connKey)
+	}
+
+	// Signal the connection to close
+	close(conn.Done)
+
+	// Close the WebSocket connection
+	if conn.WSConn != nil {
+		if err := conn.WSConn.Close(); err != nil {
+			logger.Log(
+				logger.LevelError,
+				map[string]string{"clusterID": clusterID, "userID": userID},
+				err,
+				"closing WebSocket connection",
+			)
+		}
+	}
+
+	// Update the connection status
+	conn.updateStatus(StateClosed, nil)
+
+	// Remove the connection from the map
+	delete(m.connections, connKey)
+
+	return nil
+}
+
+// createWebSocketURL creates a WebSocket URL from the given parameters.
+func createWebSocketURL(host, path, query string) string {
+	u, _ := url.Parse(host)
+	u.Scheme = "wss"
+	u.Path = path
+	u.RawQuery = query
+
+	return u.String()
+}
diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go
new file mode 100644
index 0000000000..058e01377b
--- /dev/null
+++ b/backend/cmd/multiplexer_test.go
@@ -0,0 +1,426 @@
+package main
+
+import (
+	"crypto/tls"
+	"encoding/json"
+	"fmt"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gorilla/websocket"
+	"github.com/headlamp-k8s/headlamp/backend/pkg/kubeconfig"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"k8s.io/client-go/tools/clientcmd/api"
+)
+
+func newTestDialer() *websocket.Dialer {
+	return &websocket.Dialer{
+		NetDial:          net.Dial,
+		HandshakeTimeout: 45 * time.Second,
+	}
+}
+
+func TestNewMultiplexer(t *testing.T) {
+	store := kubeconfig.NewContextStore()
+	m := NewMultiplexer(store)
+
+	assert.NotNil(t, m)
+	assert.Equal(t, store, m.kubeConfigStore)
+	assert.NotNil(t, m.connections)
+	assert.NotNil(t, m.upgrader)
+}
+
+func TestHandleClientWebSocket(t *testing.T) {
+	store := kubeconfig.NewContextStore()
+	m := NewMultiplexer(store)
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		m.HandleClientWebSocket(w, r)
+	}))
+	defer server.Close()
+
+	url := "ws" + strings.TrimPrefix(server.URL, "http")
+
+	dialer := newTestDialer()
+
+	conn, resp, err := dialer.Dial(url, nil)
+	if err == nil {
+		defer conn.Close()
+	}
+
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
+	}
+
+	assert.NoError(t, err, "Should successfully establish WebSocket connection")
+}
+
+func TestGetClusterConfigWithFallback(t *testing.T) {
+	store := kubeconfig.NewContextStore()
+	m := NewMultiplexer(store)
+
+	// Add a mock cluster config
+	err := store.AddContext(&kubeconfig.Context{
+		Name: "test-cluster",
+		Cluster: &api.Cluster{
+			Server: "https://test-cluster.example.com",
+		},
+	})
+	require.NoError(t, err)
+
+	config, err := m.getClusterConfigWithFallback("test-cluster", "test-user")
+	assert.NoError(t, err)
+	assert.NotNil(t, config)
+
+	// Test the fallback path with a cluster that does not exist
+	config, err = m.getClusterConfigWithFallback("non-existent", "test-user")
+	assert.Error(t, err)
+	assert.Nil(t, config)
+}
+
+func TestCreateConnection(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, _ := createTestWebSocketConnection()
+
+	conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+	assert.NotNil(t, conn)
+	assert.Equal(t, "test-cluster", conn.ClusterID)
+	assert.Equal(t, "test-user", conn.UserID)
+	assert.Equal(t, "/api/v1/pods", conn.Path)
+	assert.Equal(t, StateConnecting, conn.Status.State)
+}
+
+func TestDialWebSocket(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		upgrader := websocket.Upgrader{
+			CheckOrigin: func(r *http.Request) bool {
+				return true // Allow all connections for testing
+			},
+		}
+
+		c, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			t.Logf("Upgrade error: %v", err)
+			return
+		}
+
+		defer c.Close()
+		// Echo incoming messages back to the client
+		for {
+			mt, message, err := c.ReadMessage()
+			if err != nil {
+				break
+			}
+			err = c.WriteMessage(mt, message)
+			if err != nil {
+				break
+			}
+		}
+	}))
+
+	defer server.Close()
+
+	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
+	conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec
+	assert.NoError(t, err)
+	assert.NotNil(t, conn)
+
+	if conn != nil {
+		conn.Close()
+	}
+}
+
+func TestMonitorConnection(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, _ := createTestWebSocketConnection()
+
+	conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+	conn.WSConn, _ = createTestWebSocketConnection()
+
+	done := make(chan struct{})
+	go func() {
+		m.monitorConnection(conn)
+		close(done)
+	}()
+
+	time.Sleep(100 * time.Millisecond)
+	close(conn.Done)
+	<-done
+
+	assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+//nolint:funlen
+func TestHandleClusterMessages(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, clientServer := createTestWebSocketConnection()
+
+	defer clientServer.Close()
+
+	clusterConn, clusterServer := createTestWebSocketConnection()
+
+	defer clusterServer.Close()
+
+	conn := m.createConnection("minikube", "test-user", "/api/v1/pods", "watch=true", clientConn)
+	conn.WSConn = clusterConn
+
+	done := make(chan struct{})
+	go func() {
+		m.handleClusterMessages(conn, clientConn)
+		close(done)
+	}()
+
+	// Send a test message from the cluster
+	testMessage := []byte(`{"kind":"Pod","apiVersion":"v1","metadata":{"name":"test-pod"}}`)
+	err := clusterConn.WriteMessage(websocket.TextMessage, testMessage)
+	require.NoError(t, err)
+
+	// Read the message from the client connection
+	_, receivedMessage, err := clientConn.ReadMessage()
+	require.NoError(t, err)
+
+	var wrapperMsg struct {
+		ClusterID string `json:"clusterId"`
+		Path      string `json:"path"`
+		Query     string `json:"query"`
+		UserID    string `json:"userId"`
+		Data      string `json:"data"`
+		Binary    bool   `json:"binary"`
+	}
+
+	err = json.Unmarshal(receivedMessage, &wrapperMsg)
+	require.NoError(t, err)
+
+	assert.Equal(t, "minikube", wrapperMsg.ClusterID)
+	assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+	assert.Equal(t, "watch=true", wrapperMsg.Query)
+	assert.Equal(t, "test-user", wrapperMsg.UserID)
+	assert.False(t, wrapperMsg.Binary)
+
+	// Parse the Data field separately
+	var podData map[string]interface{}
+	err = json.Unmarshal([]byte(wrapperMsg.Data), &podData)
+	require.NoError(t, err)
+	assert.Equal(t, "Pod", podData["kind"])
+	assert.Equal(t, "v1", podData["apiVersion"])
+
+	// Close the connection to trigger the end of handleClusterMessages
+	conn.WSConn.Close()
+	// Wait for handleClusterMessages to finish
+	select {
+	case <-done:
+		// Function completed successfully
+	case <-time.After(5 * time.Second):
+		t.Fatal("Test timed out")
+	}
+
+	assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+func TestCleanupConnections(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, _ := createTestWebSocketConnection()
+	conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+	conn.WSConn, _ = createTestWebSocketConnection()
+
+	// Keys follow the clusterID:path:userID format
+	connKey := "test-cluster:/api/v1/pods:test-user"
+	m.connections[connKey] = conn
+
+	m.cleanupConnections()
+
+	assert.Empty(t, m.connections)
+	assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		upgrader := websocket.Upgrader{}
+		c, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			return
+		}
+
+		defer c.Close()
+
+		for {
+			mt, message, err := c.ReadMessage()
+			if err != nil {
+				break
+			}
+
+			err = c.WriteMessage(mt, message)
+			if err != nil {
+				break
+			}
+		}
+	}))
+
+	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
+	dialer := newTestDialer()
+
+	ws, resp, err := dialer.Dial(wsURL, nil)
+	if err != nil {
+		panic(err)
+	}
+
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
+	}
+
+	return ws, server
+}
+
+func TestCloseConnection(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, _ := createTestWebSocketConnection()
+	conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+	conn.WSConn, _ = createTestWebSocketConnection()
+
+	connKey := "test-cluster:/api/v1/pods:test-user"
+	m.connections[connKey] = conn
+
+	err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user")
+	assert.NoError(t, err)
+	assert.Empty(t, m.connections)
+	assert.Equal(t, StateClosed, conn.Status.State)
+
+	// Test closing a non-existent connection
+	err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user")
+	assert.Error(t, err)
+}
+
+func TestCreateWrapperMessage(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	conn := &Connection{
+		ClusterID: "test-cluster",
+		Path:      "/api/v1/pods",
+		Query:     "watch=true",
+		UserID:    "test-user",
+	}
+
+	// Test text message
+	textMsg := []byte("Hello, World!")
+	wrapperMsg := m.createWrapperMessage(conn, websocket.TextMessage, textMsg)
+	assert.Equal(t, "test-cluster", wrapperMsg.ClusterID)
+	assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+	assert.Equal(t, "watch=true", wrapperMsg.Query)
+	assert.Equal(t, "test-user", wrapperMsg.UserID)
+	assert.Equal(t, "Hello, World!", wrapperMsg.Data)
+	assert.False(t, wrapperMsg.Binary)
+
+	// Test binary message
+	binaryMsg := []byte{0x01, 0x02, 0x03}
+	wrapperMsg = m.createWrapperMessage(conn, websocket.BinaryMessage, binaryMsg)
+	assert.Equal(t, "test-cluster", wrapperMsg.ClusterID)
+	assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+	assert.Equal(t, "watch=true", wrapperMsg.Query)
+	assert.Equal(t, "test-user", wrapperMsg.UserID)
+	assert.Equal(t, "AQID", wrapperMsg.Data) // Base64 encoded
+	assert.True(t, wrapperMsg.Binary)
+}
+
+func TestHandleConnectionError(t *testing.T) {
+	m := NewMultiplexer(kubeconfig.NewContextStore())
+	clientConn, clientServer := createTestWebSocketConnection()
+
+	defer clientServer.Close()
+
+	msg := 
Message{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + + testError := fmt.Errorf("test error") + + // Capture the error message sent to the client + var receivedMsg struct { + ClusterID string `json:"clusterId"` + Error string `json:"error"` + } + + done := make(chan bool) + go func() { + _, rawMsg, err := clientConn.ReadMessage() + if err != nil { + t.Errorf("Error reading message: %v", err) + done <- true + + return + } + + err = json.Unmarshal(rawMsg, &receivedMsg) + if err != nil { + t.Errorf("Error unmarshaling message: %v", err) + done <- true + + return + } + + done <- true + }() + + m.handleConnectionError(clientConn, msg, testError) + + select { + case <-done: + assert.Equal(t, "test-cluster", receivedMsg.ClusterID) + assert.Equal(t, "test error", receivedMsg.Error) + case <-time.After(time.Second): + t.Fatal("Test timed out") + } +} + +func TestWriteMessageToCluster(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clusterConn, clusterServer := createTestWebSocketConnection() + + defer clusterServer.Close() + + conn := &Connection{ + ClusterID: "test-cluster", + WSConn: clusterConn, + } + + testMessage := []byte("Hello, Cluster!") + + // Capture the message sent to the cluster + var receivedMessage []byte + + done := make(chan bool) + go func() { + _, receivedMessage, _ = clusterConn.ReadMessage() + done <- true + }() + + err := m.writeMessageToCluster(conn, testMessage) + assert.NoError(t, err) + + select { + case <-done: + assert.Equal(t, testMessage, receivedMessage) + case <-time.After(time.Second): + t.Fatal("Test timed out") + } + + // Test error case + clusterConn.Close() + + err = m.writeMessageToCluster(conn, testMessage) + + assert.Error(t, err) + assert.Equal(t, StateError, conn.Status.State) +} diff --git a/backend/cmd/server.go b/backend/cmd/server.go index 736e10a365..bca276986a 100644 --- a/backend/cmd/server.go +++ b/backend/cmd/server.go @@ -34,6 +34,7 @@ func main() { cache := cache.New[interface{}]() kubeConfigStore := kubeconfig.NewContextStore() + multiplexer := NewMultiplexer(kubeConfigStore) StartHeadlampServer(&HeadlampConfig{ useInCluster: conf.InCluster, @@ -53,5 +54,6 @@ func main() { enableDynamicClusters: conf.EnableDynamicClusters, cache: cache, kubeConfigStore: kubeConfigStore, + multiplexer: multiplexer, }) }
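One last protocol detail worth calling out: a client tears down a single upstream connection by sending the literal payload "close" (see the check in HandleClientWebSocket). Continuing the client sketch from earlier, with `conn` and `wsMessage` as defined there:

```go
// Continuing the earlier client sketch: a Data payload of "close" asks the
// multiplexer to shut down the matching cluster connection. encoding/json
// base64-encodes the []byte field, and the backend decodes it back to "close".
closeMsg := wsMessage{
	ClusterID: "minikube",
	Path:      "/api/v1/pods",
	UserID:    "user-1",
	Data:      []byte("close"),
}
if err := conn.WriteJSON(closeMsg); err != nil {
	log.Fatalf("write close: %v", err)
}
```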