From bf2ec6d2fbca83700762ad1d1881f0d7ba5a9352 Mon Sep 17 00:00:00 2001
From: Kautilya Tripathi
Date: Mon, 21 Oct 2024 11:44:18 +0530
Subject: [PATCH] backend: Add websocket multiplexer

This adds a websocket multiplexer to the backend. The frontend now makes
a single websocket connection to the backend. Once that connection is
established, the frontend sends messages to the backend with the
appropriate data. The backend opens multiple websockets to the k8s
servers and acts as a proxy: it forwards requests to the k8s server and
returns the data to the frontend. This also adds retry logic in case the
connection between the frontend and backend is broken.

Fixes: #1802

Signed-off-by: Kautilya Tripathi
---
 backend/cmd/headlamp.go         |   4 +
 backend/cmd/multiplexer.go      | 632 ++++++++++++++++++++++++++++++++
 backend/cmd/multiplexer_test.go | 426 +++++++++++++++++++++
 backend/cmd/server.go           |   2 +
 4 files changed, 1064 insertions(+)
 create mode 100644 backend/cmd/multiplexer.go
 create mode 100644 backend/cmd/multiplexer_test.go

diff --git a/backend/cmd/headlamp.go b/backend/cmd/headlamp.go
index 1251f84e56..06074e9881 100644
--- a/backend/cmd/headlamp.go
+++ b/backend/cmd/headlamp.go
@@ -59,6 +59,7 @@ type HeadlampConfig struct {
     proxyURLs       []string
     cache           cache.Cache[interface{}]
     kubeConfigStore kubeconfig.ContextStore
+    multiplexer     *Multiplexer
 }
 
 const DrainNodeCacheTTL = 20 // seconds
@@ -1599,6 +1600,9 @@ func (c *HeadlampConfig) addClusterSetupRoute(r *mux.Router) {
 
     // Rename a cluster
     r.HandleFunc("/cluster/{name}", c.renameCluster).Methods("PUT")
+
+    // Websocket connections
+    r.HandleFunc("/wsMultiplexer", c.multiplexer.HandleClientWebSocket)
 }
 
 /*
diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go
new file mode 100644
index 0000000000..ad37f91359
--- /dev/null
+++ b/backend/cmd/multiplexer.go
@@ -0,0 +1,632 @@
+package main
+
+import (
+    "crypto/tls"
+    "encoding/base64"
+    "encoding/json"
+    "fmt"
+    "net/http"
+    "net/url"
+    "sync"
+    "time"
+
+    "github.com/gorilla/websocket"
+    "github.com/headlamp-k8s/headlamp/backend/pkg/kubeconfig"
+    "github.com/headlamp-k8s/headlamp/backend/pkg/logger"
+    "k8s.io/client-go/rest"
+)
+
+const (
+    // StateConnecting is the state when the connection is being established.
+    StateConnecting ConnectionState = "connecting"
+    // StateConnected is the state when the connection is established.
+    StateConnected ConnectionState = "connected"
+    // StateError is the state when the connection has an error.
+    StateError ConnectionState = "error"
+    // StateClosed is the state when the connection is closed.
+    StateClosed ConnectionState = "closed"
+)
+
+const (
+    // HeartbeatInterval is the interval at which the multiplexer sends heartbeat messages to the client.
+    HeartbeatInterval = 30 * time.Second
+    // HandshakeTimeout is the timeout for the handshake with the client.
+    HandshakeTimeout = 45 * time.Second
+    // CleanupRoutineInterval is the interval at which the multiplexer cleans up unused connections.
+    CleanupRoutineInterval = 5 * time.Minute
+)
+
+// ConnectionState represents the current state of a connection.
+type ConnectionState string
+
+type ConnectionStatus struct {
+    // State is the current state of the connection.
+    State ConnectionState `json:"state"`
+    // Error is the error message of the connection.
+    Error string `json:"error,omitempty"`
+    // LastMsg is the last message time of the connection.
+    LastMsg time.Time `json:"lastMsg"`
+}
+
+// Connection represents a WebSocket connection to a Kubernetes cluster.
+type Connection struct {
+    // ClusterID is the ID of the cluster.
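+    // It is also the key used to look up the cluster's kubeconfig context.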
+    ClusterID string
+    // UserID is the ID of the user.
+    UserID string
+    // Path is the path of the connection.
+    Path string
+    // Query is the query of the connection.
+    Query string
+    // WSConn is the WebSocket connection to the cluster.
+    WSConn *websocket.Conn
+    // Status is the status of the connection.
+    Status ConnectionStatus
+    // Client is the WebSocket connection to the client.
+    Client *websocket.Conn
+    // Done is a channel to signal when the connection is done.
+    Done chan struct{}
+    // mu is a mutex to synchronize access to the connection.
+    mu sync.RWMutex
+}
+
+// Message represents a WebSocket message structure.
+type Message struct {
+    // ClusterID is the ID of the cluster.
+    ClusterID string `json:"clusterId"`
+    // Path is the path of the connection.
+    Path string `json:"path"`
+    // Query is the query of the connection.
+    Query string `json:"query"`
+    // UserID is the ID of the user.
+    UserID string `json:"userId"`
+    // Data contains the message payload.
+    Data []byte `json:"data,omitempty"`
+}
+
+// Multiplexer manages multiple WebSocket connections.
+type Multiplexer struct {
+    // connections is a map of connections indexed by cluster ID, path, and user ID.
+    connections map[string]*Connection
+    // mutex is a mutex to synchronize access to the connections.
+    mutex sync.RWMutex
+    // upgrader is the WebSocket upgrader.
+    upgrader websocket.Upgrader
+    // kubeConfigStore is the kubeconfig store.
+    kubeConfigStore kubeconfig.ContextStore
+}
+
+// NewMultiplexer creates a new Multiplexer instance.
+func NewMultiplexer(kubeConfigStore kubeconfig.ContextStore) *Multiplexer {
+    return &Multiplexer{
+        connections:     make(map[string]*Connection),
+        kubeConfigStore: kubeConfigStore,
+        upgrader: websocket.Upgrader{
+            CheckOrigin: func(r *http.Request) bool {
+                return true
+            },
+        },
+    }
+}
+
+// updateStatus updates the status of a connection and notifies the client.
+func (c *Connection) updateStatus(state ConnectionState, err error) {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+
+    c.Status.State = state
+    c.Status.LastMsg = time.Now()
+
+    if err != nil {
+        c.Status.Error = err.Error()
+    } else {
+        c.Status.Error = ""
+    }
+
+    if c.Client != nil {
+        statusData := struct {
+            State string `json:"state"`
+            Error string `json:"error"`
+        }{
+            State: string(state),
+            Error: c.Status.Error,
+        }
+
+        jsonData, jsonErr := json.Marshal(statusData)
+        if jsonErr != nil {
+            logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message")
+
+            return
+        }
+
+        statusMsg := Message{
+            ClusterID: c.ClusterID,
+            Path:      c.Path,
+            Data:      jsonData,
+        }
+
+        err := c.Client.WriteJSON(statusMsg)
+        if err != nil {
+            logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client")
+        }
+    }
+}
+
+// establishClusterConnection creates a new WebSocket connection to a Kubernetes cluster.
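+// It resolves the cluster's REST config, dials the cluster's WebSocket endpoint
+// using that config's TLS settings, registers the connection in the connections
+// map, and starts a goroutine that monitors the connection's health.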
+func (m *Multiplexer) establishClusterConnection(
+    clusterID,
+    userID,
+    path,
+    query string,
+    clientConn *websocket.Conn,
+) (*Connection, error) {
+    config, err := m.getClusterConfigWithFallback(clusterID, userID)
+    if err != nil {
+        logger.Log(logger.LevelError, map[string]string{"clusterID": clusterID}, err, "getting cluster config")
+
+        return nil, err
+    }
+
+    connection := m.createConnection(clusterID, userID, path, query, clientConn)
+
+    wsURL := createWebSocketURL(config.Host, path, query)
+
+    tlsConfig, err := rest.TLSConfigFor(config)
+    if err != nil {
+        connection.updateStatus(StateError, err)
+
+        return nil, fmt.Errorf("failed to get TLS config: %v", err)
+    }
+
+    conn, err := m.dialWebSocket(wsURL, tlsConfig, config.Host)
+    if err != nil {
+        connection.updateStatus(StateError, err)
+
+        return nil, err
+    }
+
+    connection.WSConn = conn
+    connection.updateStatus(StateConnected, nil)
+
+    m.mutex.Lock()
+    m.connections[fmt.Sprintf("%s:%s:%s", clusterID, path, userID)] = connection
+    m.mutex.Unlock()
+
+    go m.monitorConnection(connection)
+
+    return connection, nil
+}
+
+// getClusterConfigWithFallback attempts to get the cluster config,
+// falling back to a combined key for stateless clusters.
+func (m *Multiplexer) getClusterConfigWithFallback(clusterID, userID string) (*rest.Config, error) {
+    // Try to get config for stateful cluster first.
+    config, err := m.getClusterConfig(clusterID)
+    if err != nil {
+        // If not found, try with the combined key for stateless clusters.
+        combinedKey := fmt.Sprintf("%s%s", clusterID, userID)
+
+        config, err = m.getClusterConfig(combinedKey)
+        if err != nil {
+            return nil, fmt.Errorf("getting cluster config: %v", err)
+        }
+    }
+
+    return config, nil
+}
+
+// createConnection creates a new Connection instance.
+func (m *Multiplexer) createConnection(
+    clusterID,
+    userID,
+    path,
+    query string,
+    clientConn *websocket.Conn,
+) *Connection {
+    return &Connection{
+        ClusterID: clusterID,
+        UserID:    userID,
+        Path:      path,
+        Query:     query,
+        Client:    clientConn,
+        Done:      make(chan struct{}),
+        Status: ConnectionStatus{
+            State:   StateConnecting,
+            LastMsg: time.Now(),
+        },
+    }
+}
+
+// dialWebSocket establishes a WebSocket connection.
+func (m *Multiplexer) dialWebSocket(wsURL string, tlsConfig *tls.Config, host string) (*websocket.Conn, error) {
+    dialer := websocket.Dialer{
+        TLSClientConfig:  tlsConfig,
+        HandshakeTimeout: HandshakeTimeout,
+    }
+
+    conn, resp, err := dialer.Dial(
+        wsURL,
+        http.Header{
+            "Origin": {host},
+        },
+    )
+    if err != nil {
+        logger.Log(logger.LevelError, nil, err, "dialing WebSocket")
+        // On a failed handshake, gorilla/websocket may return a non-nil
+        // response so that callers can inspect it; close its body so it
+        // is not leaked.
+        if resp != nil {
+            defer resp.Body.Close()
+        }
+
+        return nil, fmt.Errorf("dialing WebSocket: %v", err)
+    }
+
+    return conn, nil
+}
+
+// monitorConnection monitors the health of a connection and attempts to reconnect if necessary.
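+// It pings the cluster every HeartbeatInterval; if a ping fails, the connection
+// is marked with StateError and a reconnect is attempted.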
+func (m *Multiplexer) monitorConnection(conn *Connection) {
+    heartbeat := time.NewTicker(HeartbeatInterval)
+    defer heartbeat.Stop()
+
+    for {
+        select {
+        case <-conn.Done:
+            conn.updateStatus(StateClosed, nil)
+
+            return
+        case <-heartbeat.C:
+            if err := conn.WSConn.WriteMessage(websocket.PingMessage, nil); err != nil {
+                conn.updateStatus(StateError, fmt.Errorf("heartbeat failed: %v", err))
+
+                if newConn, err := m.reconnect(conn); err != nil {
+                    logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, err, "reconnecting to cluster")
+                } else {
+                    conn = newConn
+                }
+            }
+        }
+    }
+}
+
+// reconnect attempts to reestablish a connection.
+func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) {
+    if conn.WSConn != nil {
+        conn.WSConn.Close()
+    }
+
+    newConn, err := m.establishClusterConnection(
+        conn.ClusterID,
+        conn.UserID,
+        conn.Path,
+        conn.Query,
+        conn.Client,
+    )
+    if err != nil {
+        logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, err, "reconnecting to cluster")
+
+        return nil, err
+    }
+
+    m.mutex.Lock()
+    m.connections[fmt.Sprintf("%s:%s:%s", conn.ClusterID, conn.Path, conn.UserID)] = newConn
+    m.mutex.Unlock()
+
+    return newConn, nil
+}
+
+// HandleClientWebSocket handles incoming WebSocket connections from clients.
+func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Request) {
+    clientConn, err := m.upgrader.Upgrade(w, r, nil)
+    if err != nil {
+        logger.Log(logger.LevelError, nil, err, "upgrading connection")
+
+        return
+    }
+
+    defer clientConn.Close()
+
+    for {
+        msg, err := m.readClientMessage(clientConn)
+        if err != nil {
+            break
+        }
+
+        // Check if it's a close message.
+        if len(msg.Data) > 0 && string(msg.Data) == "close" {
+            err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID)
+            if err != nil {
+                logger.Log(
+                    logger.LevelError,
+                    map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID},
+                    err,
+                    "closing connection",
+                )
+            }
+
+            continue
+        }
+
+        conn, err := m.getOrCreateConnection(msg, clientConn)
+        if err != nil {
+            m.handleConnectionError(clientConn, msg, err)
+
+            continue
+        }
+
+        if len(msg.Data) > 0 && conn.Status.State == StateConnected {
+            err = m.writeMessageToCluster(conn, msg.Data)
+            if err != nil {
+                continue
+            }
+        }
+    }
+
+    m.cleanupConnections()
+}
+
+// readClientMessage reads a message from the client WebSocket connection.
+func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, error) {
+    var msg Message
+
+    _, rawMessage, err := clientConn.ReadMessage()
+    if err != nil {
+        logger.Log(logger.LevelError, nil, err, "reading message")
+
+        return Message{}, err
+    }
+
+    err = json.Unmarshal(rawMessage, &msg)
+    if err != nil {
+        logger.Log(logger.LevelError, nil, err, "unmarshaling message")
+
+        return Message{}, err
+    }
+
+    return msg, nil
+}
+
+// getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist.
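+// Connections are keyed by "<clusterID>:<path>:<userID>", so each
+// cluster/path/user combination shares one upstream cluster connection.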
+func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) {
+    connKey := fmt.Sprintf("%s:%s:%s", msg.ClusterID, msg.Path, msg.UserID)
+
+    m.mutex.RLock()
+    conn, exists := m.connections[connKey]
+    m.mutex.RUnlock()
+
+    if !exists {
+        var err error
+
+        conn, err = m.establishClusterConnection(msg.ClusterID, msg.UserID, msg.Path, msg.Query, clientConn)
+        if err != nil {
+            logger.Log(
+                logger.LevelError,
+                map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID},
+                err,
+                "establishing cluster connection",
+            )
+
+            return nil, err
+        }
+
+        go m.handleClusterMessages(conn, clientConn)
+    }
+
+    return conn, nil
+}
+
+// handleConnectionError handles errors that occur when establishing a connection.
+func (m *Multiplexer) handleConnectionError(clientConn *websocket.Conn, msg Message, err error) {
+    errorMsg := struct {
+        ClusterID string `json:"clusterId"`
+        Error     string `json:"error"`
+    }{
+        ClusterID: msg.ClusterID,
+        Error:     err.Error(),
+    }
+
+    if writeErr := clientConn.WriteJSON(errorMsg); writeErr != nil {
+        logger.Log(
+            logger.LevelError,
+            map[string]string{"clusterID": msg.ClusterID},
+            writeErr,
+            "writing error message to client",
+        )
+    }
+
+    logger.Log(logger.LevelError, map[string]string{"clusterID": msg.ClusterID}, err, "establishing cluster connection")
+}
+
+// writeMessageToCluster writes a message to the cluster WebSocket connection.
+func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error {
+    err := conn.WSConn.WriteMessage(websocket.BinaryMessage, data)
+    if err != nil {
+        conn.updateStatus(StateError, err)
+        logger.Log(
+            logger.LevelError,
+            map[string]string{"clusterID": conn.ClusterID},
+            err,
+            "writing message to cluster",
+        )
+
+        return err
+    }
+
+    return nil
+}
+
+// handleClusterMessages handles messages from a cluster connection.
+func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) {
+    defer func() {
+        conn.updateStatus(StateClosed, nil)
+        conn.WSConn.Close()
+    }()
+
+    for {
+        select {
+        case <-conn.Done:
+            return
+        default:
+            if err := m.processClusterMessage(conn, clientConn); err != nil {
+                return
+            }
+        }
+    }
+}
+
+// processClusterMessage processes a message from a cluster connection.
+func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error {
+    messageType, message, err := conn.WSConn.ReadMessage()
+    if err != nil {
+        m.handleReadError(conn, err)
+
+        return err
+    }
+
+    wrapperMsg := m.createWrapperMessage(conn, messageType, message)
+
+    if err := clientConn.WriteJSON(wrapperMsg); err != nil {
+        m.handleWriteError(conn, err)
+
+        return err
+    }
+
+    conn.mu.Lock()
+    conn.Status.LastMsg = time.Now()
+    conn.mu.Unlock()
+
+    return nil
+}
+
+// createWrapperMessage creates a wrapper message for a cluster connection.
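+// Binary frames are base64-encoded and flagged with Binary=true so that the
+// wrapper can always be sent to the client as JSON text.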
+func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct {
+    ClusterID string `json:"clusterId"`
+    Path      string `json:"path"`
+    Query     string `json:"query"`
+    UserID    string `json:"userId"`
+    Data      string `json:"data"`
+    Binary    bool   `json:"binary"`
+} {
+    wrapperMsg := struct {
+        ClusterID string `json:"clusterId"`
+        Path      string `json:"path"`
+        Query     string `json:"query"`
+        UserID    string `json:"userId"`
+        Data      string `json:"data"`
+        Binary    bool   `json:"binary"`
+    }{
+        ClusterID: conn.ClusterID,
+        Path:      conn.Path,
+        Query:     conn.Query,
+        UserID:    conn.UserID,
+        Binary:    messageType == websocket.BinaryMessage,
+    }
+
+    if messageType == websocket.BinaryMessage {
+        wrapperMsg.Data = base64.StdEncoding.EncodeToString(message)
+    } else {
+        wrapperMsg.Data = string(message)
+    }
+
+    return wrapperMsg
+}
+
+// handleReadError handles errors that occur when reading a message from a cluster connection.
+func (m *Multiplexer) handleReadError(conn *Connection, err error) {
+    conn.updateStatus(StateError, err)
+    logger.Log(
+        logger.LevelError,
+        map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID},
+        err,
+        "reading message from cluster",
+    )
+}
+
+// handleWriteError handles errors that occur when writing a message to a client connection.
+func (m *Multiplexer) handleWriteError(conn *Connection, err error) {
+    conn.updateStatus(StateError, err)
+    logger.Log(
+        logger.LevelError,
+        map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID},
+        err,
+        "writing message to client",
+    )
+}
+
+// cleanupConnections closes and removes all connections.
+func (m *Multiplexer) cleanupConnections() {
+    m.mutex.Lock()
+    defer m.mutex.Unlock()
+
+    for key, conn := range m.connections {
+        conn.updateStatus(StateClosed, nil)
+        close(conn.Done)
+
+        if conn.WSConn != nil {
+            conn.WSConn.Close()
+        }
+
+        delete(m.connections, key)
+    }
+}
+
+// getClusterConfig retrieves the REST config for a given cluster.
+func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) {
+    ctxtProxy, err := m.kubeConfigStore.GetContext(clusterID)
+    if err != nil {
+        return nil, fmt.Errorf("getting context: %v", err)
+    }
+
+    clientConfig, err := ctxtProxy.RESTConfig()
+    if err != nil {
+        return nil, fmt.Errorf("getting REST config: %v", err)
+    }
+
+    return clientConfig, nil
+}
+
+// CloseConnection closes a specific connection based on its identifier.
+func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error {
+    connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID)
+
+    m.mutex.Lock()
+    defer m.mutex.Unlock()
+
+    conn, exists := m.connections[connKey]
+    if !exists {
+        return fmt.Errorf("connection not found for key: %s", connKey)
+    }
+
+    // Signal the connection to close
+    close(conn.Done)
+
+    // Close the WebSocket connection
+    if conn.WSConn != nil {
+        if err := conn.WSConn.Close(); err != nil {
+            logger.Log(
+                logger.LevelError,
+                map[string]string{"clusterID": clusterID, "userID": userID},
+                err,
+                "closing WebSocket connection",
+            )
+        }
+    }
+
+    // Update the connection status
+    conn.updateStatus(StateClosed, nil)
+
+    // Remove the connection from the map
+    delete(m.connections, connKey)
+
+    return nil
+}
+
+// createWebSocketURL creates a WebSocket URL from the given parameters.
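+// For example, host "https://cluster.example.com:6443", path "/api/v1/pods",
+// and query "watch=true" produce
+// "wss://cluster.example.com:6443/api/v1/pods?watch=true".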
+func createWebSocketURL(host, path, query string) string {
+    u, _ := url.Parse(host)
+    u.Scheme = "wss"
+    u.Path = path
+    u.RawQuery = query
+
+    return u.String()
+}
diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go
new file mode 100644
index 0000000000..058e01377b
--- /dev/null
+++ b/backend/cmd/multiplexer_test.go
@@ -0,0 +1,426 @@
+package main
+
+import (
+    "crypto/tls"
+    "encoding/json"
+    "fmt"
+    "net"
+    "net/http"
+    "net/http/httptest"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/gorilla/websocket"
+    "github.com/headlamp-k8s/headlamp/backend/pkg/kubeconfig"
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+    "k8s.io/client-go/tools/clientcmd/api"
+)
+
+func newTestDialer() *websocket.Dialer {
+    return &websocket.Dialer{
+        NetDial:          net.Dial,
+        HandshakeTimeout: 45 * time.Second,
+    }
+}
+
+func TestNewMultiplexer(t *testing.T) {
+    store := kubeconfig.NewContextStore()
+    m := NewMultiplexer(store)
+
+    assert.NotNil(t, m)
+    assert.Equal(t, store, m.kubeConfigStore)
+    assert.NotNil(t, m.connections)
+    assert.NotNil(t, m.upgrader)
+}
+
+func TestHandleClientWebSocket(t *testing.T) {
+    store := kubeconfig.NewContextStore()
+    m := NewMultiplexer(store)
+
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        m.HandleClientWebSocket(w, r)
+    }))
+    defer server.Close()
+
+    url := "ws" + strings.TrimPrefix(server.URL, "http")
+
+    dialer := newTestDialer()
+
+    conn, resp, err := dialer.Dial(url, nil)
+    if err == nil {
+        defer conn.Close()
+    }
+
+    if resp != nil && resp.Body != nil {
+        defer resp.Body.Close()
+    }
+
+    assert.NoError(t, err, "Should successfully establish WebSocket connection")
+}
+
+func TestGetClusterConfigWithFallback(t *testing.T) {
+    store := kubeconfig.NewContextStore()
+    m := NewMultiplexer(store)
+
+    // Add a mock cluster config.
+    err := store.AddContext(&kubeconfig.Context{
+        Name: "test-cluster",
+        Cluster: &api.Cluster{
+            Server: "https://test-cluster.example.com",
+        },
+    })
+    require.NoError(t, err)
+
+    config, err := m.getClusterConfigWithFallback("test-cluster", "test-user")
+    assert.NoError(t, err)
+    assert.NotNil(t, config)
+
+    // A non-existent cluster should fail both the direct and the fallback lookup.
+    config, err = m.getClusterConfigWithFallback("non-existent", "test-user")
+    assert.Error(t, err)
+    assert.Nil(t, config)
+}
+
+func TestCreateConnection(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, _ := createTestWebSocketConnection()
+
+    conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+    assert.NotNil(t, conn)
+    assert.Equal(t, "test-cluster", conn.ClusterID)
+    assert.Equal(t, "test-user", conn.UserID)
+    assert.Equal(t, "/api/v1/pods", conn.Path)
+    assert.Equal(t, StateConnecting, conn.Status.State)
+}
+
+func TestDialWebSocket(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        upgrader := websocket.Upgrader{
+            CheckOrigin: func(r *http.Request) bool {
+                return true // Allow all connections for testing
+            },
+        }
+
+        c, err := upgrader.Upgrade(w, r, nil)
+        if err != nil {
+            t.Logf("Upgrade error: %v", err)
+            return
+        }
+
+        defer c.Close()
+        // Echo incoming messages back to the client
+        for {
+            mt, message, err := c.ReadMessage()
+            if err != nil {
+                break
+            }
+            err = c.WriteMessage(mt, message)
+            if err != nil {
+                break
+            }
+        }
+    }))
+
+    defer server.Close()
+
+    wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
+    conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec
+    assert.NoError(t, err)
+    assert.NotNil(t, conn)
+
+    if conn != nil {
+        conn.Close()
+    }
+}
+
+func TestMonitorConnection(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, _ := createTestWebSocketConnection()
+
+    conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+    conn.WSConn, _ = createTestWebSocketConnection()
+
+    done := make(chan struct{})
+    go func() {
+        m.monitorConnection(conn)
+        close(done)
+    }()
+
+    time.Sleep(100 * time.Millisecond)
+    close(conn.Done)
+    <-done
+
+    assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+//nolint:funlen
+func TestHandleClusterMessages(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, clientServer := createTestWebSocketConnection()
+
+    defer clientServer.Close()
+
+    clusterConn, clusterServer := createTestWebSocketConnection()
+
+    defer clusterServer.Close()
+
+    conn := m.createConnection("minikube", "test-user", "/api/v1/pods", "watch=true", clientConn)
+    conn.WSConn = clusterConn
+
+    done := make(chan struct{})
+    go func() {
+        m.handleClusterMessages(conn, clientConn)
+        close(done)
+    }()
+
+    // Send a test message from the cluster
+    testMessage := []byte(`{"kind":"Pod","apiVersion":"v1","metadata":{"name":"test-pod"}}`)
+    err := clusterConn.WriteMessage(websocket.TextMessage, testMessage)
+    require.NoError(t, err)
+
+    // Read the message from the client connection
+    _, receivedMessage, err := clientConn.ReadMessage()
+    require.NoError(t, err)
+
+    var wrapperMsg struct {
+        ClusterID string `json:"clusterId"`
+        Path      string `json:"path"`
+        Query     string `json:"query"`
+        UserID    string `json:"userId"`
+        Data      string `json:"data"`
+        Binary    bool   `json:"binary"`
+    }
+
+    err = json.Unmarshal(receivedMessage, &wrapperMsg)
+    require.NoError(t, err)
+
+    assert.Equal(t, "minikube", wrapperMsg.ClusterID)
+    assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+    assert.Equal(t, "watch=true", wrapperMsg.Query)
+    assert.Equal(t, "test-user", wrapperMsg.UserID)
+    assert.False(t, wrapperMsg.Binary)
+
+    // Parse the Data field separately
+    var podData map[string]interface{}
+    err = json.Unmarshal([]byte(wrapperMsg.Data), &podData)
+    require.NoError(t, err)
+    assert.Equal(t, "Pod", podData["kind"])
+    assert.Equal(t, "v1", podData["apiVersion"])
+
+    // Close the connection to trigger the end of handleClusterMessages
+    conn.WSConn.Close()
+
+    // Wait for handleClusterMessages to finish
+    select {
+    case <-done:
+        // Function completed successfully
+    case <-time.After(5 * time.Second):
+        t.Fatal("Test timed out")
+    }
+
+    assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+func TestCleanupConnections(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, _ := createTestWebSocketConnection()
+    conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+    conn.WSConn, _ = createTestWebSocketConnection()
+
+    // Use the same "<clusterID>:<path>:<userID>" key format as the multiplexer.
+    connKey := "test-cluster:/api/v1/pods:test-user"
+    m.connections[connKey] = conn
+
+    m.cleanupConnections()
+
+    assert.Empty(t, m.connections)
+    assert.Equal(t, StateClosed, conn.Status.State)
+}
+
+func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) {
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        upgrader := websocket.Upgrader{}
+        c, err := upgrader.Upgrade(w, r, nil)
+        if err != nil {
+            return
+        }
+
+        defer c.Close()
+
+        for {
+            mt, message, err := c.ReadMessage()
+            if err != nil {
+                break
+            }
+
+            err = c.WriteMessage(mt, message)
+            if err != nil {
+                break
+            }
+        }
+    }))
+
+    wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
+    dialer := newTestDialer()
+
+    ws, resp, err := dialer.Dial(wsURL, nil)
+    if err != nil {
+        panic(err)
+    }
+
+    if resp != nil && resp.Body != nil {
+        defer resp.Body.Close()
+    }
+
+    return ws, server
+}
+
+func TestCloseConnection(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, _ := createTestWebSocketConnection()
+    conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
+    conn.WSConn, _ = createTestWebSocketConnection()
+
+    connKey := "test-cluster:/api/v1/pods:test-user"
+    m.connections[connKey] = conn
+
+    err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user")
+    assert.NoError(t, err)
+    assert.Empty(t, m.connections)
+    assert.Equal(t, StateClosed, conn.Status.State)
+
+    // Test closing a non-existent connection
+    err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user")
+    assert.Error(t, err)
+}
+
+func TestCreateWrapperMessage(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    conn := &Connection{
+        ClusterID: "test-cluster",
+        Path:      "/api/v1/pods",
+        Query:     "watch=true",
+        UserID:    "test-user",
+    }
+
+    // Test text message
+    textMsg := []byte("Hello, World!")
+    wrapperMsg := m.createWrapperMessage(conn, websocket.TextMessage, textMsg)
+    assert.Equal(t, "test-cluster", wrapperMsg.ClusterID)
+    assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+    assert.Equal(t, "watch=true", wrapperMsg.Query)
+    assert.Equal(t, "test-user", wrapperMsg.UserID)
+    assert.Equal(t, "Hello, World!", wrapperMsg.Data)
+    assert.False(t, wrapperMsg.Binary)
+
+    // Test binary message
+    binaryMsg := []byte{0x01, 0x02, 0x03}
+    wrapperMsg = m.createWrapperMessage(conn, websocket.BinaryMessage, binaryMsg)
+    assert.Equal(t, "test-cluster", wrapperMsg.ClusterID)
+    assert.Equal(t, "/api/v1/pods", wrapperMsg.Path)
+    assert.Equal(t, "watch=true", wrapperMsg.Query)
+    assert.Equal(t, "test-user", wrapperMsg.UserID)
+    assert.Equal(t, "AQID", wrapperMsg.Data) // Base64 encoded
+    assert.True(t, wrapperMsg.Binary)
+}
+
+func TestHandleConnectionError(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clientConn, clientServer := createTestWebSocketConnection()
+
+    defer clientServer.Close()
+
+    msg := Message{
+        ClusterID: "test-cluster",
+        Path:      "/api/v1/pods",
+        UserID:    "test-user",
+    }
+
+    testError := fmt.Errorf("test error")
+
+    // Capture the error message sent to the client
+    var receivedMsg struct {
+        ClusterID string `json:"clusterId"`
+        Error     string `json:"error"`
+    }
+
+    done := make(chan bool)
+    go func() {
+        _, rawMsg, err := clientConn.ReadMessage()
+        if err != nil {
+            t.Errorf("Error reading message: %v", err)
+            done <- true
+
+            return
+        }
+
+        err = json.Unmarshal(rawMsg, &receivedMsg)
+        if err != nil {
+            t.Errorf("Error unmarshaling message: %v", err)
+            done <- true
+
+            return
+        }
+
+        done <- true
+    }()
+
+    m.handleConnectionError(clientConn, msg, testError)
+
+    select {
+    case <-done:
+        assert.Equal(t, "test-cluster", receivedMsg.ClusterID)
+        assert.Equal(t, "test error", receivedMsg.Error)
+    case <-time.After(time.Second):
+        t.Fatal("Test timed out")
+    }
+}
+
+func TestWriteMessageToCluster(t *testing.T) {
+    m := NewMultiplexer(kubeconfig.NewContextStore())
+    clusterConn, clusterServer := createTestWebSocketConnection()
+
+    defer clusterServer.Close()
+
+    conn := &Connection{
+        ClusterID: "test-cluster",
+        WSConn:    clusterConn,
+    }
+
+    testMessage := []byte("Hello, Cluster!")
+
+    // Capture the message sent to the cluster
+    var receivedMessage []byte
+
+    done := make(chan bool)
+    go func() {
+        _, receivedMessage, _ = clusterConn.ReadMessage()
+        done <- true
+    }()
+
+    err := m.writeMessageToCluster(conn, testMessage)
+    assert.NoError(t, err)
+
+    select {
+    case <-done:
+        assert.Equal(t, testMessage, receivedMessage)
+    case <-time.After(time.Second):
+        t.Fatal("Test timed out")
+    }
+
+    // Test error case
+    clusterConn.Close()
+
+    err = m.writeMessageToCluster(conn, testMessage)
+
+    assert.Error(t, err)
+    assert.Equal(t, StateError, conn.Status.State)
+}
diff --git a/backend/cmd/server.go b/backend/cmd/server.go
index 736e10a365..bca276986a 100644
--- a/backend/cmd/server.go
+++ b/backend/cmd/server.go
@@ -34,6 +34,7 @@ func main() {
 
     cache := cache.New[interface{}]()
     kubeConfigStore := kubeconfig.NewContextStore()
+    multiplexer := NewMultiplexer(kubeConfigStore)
 
     StartHeadlampServer(&HeadlampConfig{
         useInCluster:          conf.InCluster,
@@ -53,5 +54,6 @@ func main() {
         enableDynamicClusters: conf.EnableDynamicClusters,
         cache:                 cache,
         kubeConfigStore:       kubeConfigStore,
+        multiplexer:           multiplexer,
     })
 }