Skip to content

Commit

Permalink
Merge pull request #2741 from headlamp-k8s/web-lock-fix
Browse files Browse the repository at this point in the history
backend: enhance WSConnLock with thread-safety
  • Loading branch information
illume authored Jan 10, 2025
2 parents a0806b2 + 19cd3e6 commit d290009
Show file tree
Hide file tree
Showing 2 changed files with 473 additions and 290 deletions.
88 changes: 76 additions & 12 deletions backend/cmd/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Connection struct {
// Status is the status of the connection.
Status ConnectionStatus
// Client is the WebSocket connection to the client.
Client *websocket.Conn
Client *WSConnLock
// Done is a channel to signal when the connection is done.
Done chan struct{}
// mu is a mutex to synchronize access to the connection.
Expand Down Expand Up @@ -104,6 +104,68 @@ type Multiplexer struct {
kubeConfigStore kubeconfig.ContextStore
}

// WSConnLock wraps a WebSocket connection together with a mutex that
// serializes write-side operations on it.
//
// WriteJSON, WriteMessage and Close hold writeMu for the duration of the
// call, so they never overlap with one another. ReadJSON and ReadMessage
// delegate to the underlying connection without taking any lock.
type WSConnLock struct {
// conn is the underlying WebSocket connection that every method delegates to.
conn *websocket.Conn
// writeMu is held by WriteJSON, WriteMessage and Close so that only one
// goroutine performs a write (or close) on conn at a time.
writeMu sync.Mutex
}

// NewWSConnLock returns a WSConnLock that wraps conn.
// The write mutex starts in its zero (unlocked) state, so the returned
// wrapper is immediately ready for use.
func NewWSConnLock(conn *websocket.Conn) *WSConnLock {
	wrapped := new(WSConnLock)
	wrapped.conn = conn

	return wrapped
}

// WriteJSON encodes v as JSON and sends it as a message on the wrapped
// WebSocket connection.
// The write mutex is held for the whole call, so writes issued through this
// wrapper from different goroutines never overlap.
// It returns the error reported by the underlying connection, if any.
func (conn *WSConnLock) WriteJSON(v interface{}) error {
	conn.writeMu.Lock()
	defer conn.writeMu.Unlock()

	ws := conn.conn

	return ws.WriteJSON(v)
}

// ReadJSON reads the next JSON-encoded message from the wrapped WebSocket
// connection and decodes it into the value pointed to by v.
// No lock is taken here: reads do not contend with the write mutex.
// Note that gorilla/websocket supports only one concurrent reader, so callers
// must make sure a single goroutine performs the reads.
func (conn *WSConnLock) ReadJSON(v interface{}) error {
	ws := conn.conn

	return ws.ReadJSON(v)
}

// ReadMessage reads the next message from the wrapped WebSocket connection
// and returns its message type, its payload and any read error.
// No lock is taken here: reads do not contend with the write mutex.
// Note that gorilla/websocket supports only one concurrent reader, so callers
// must make sure a single goroutine performs the reads.
func (conn *WSConnLock) ReadMessage() (messageType int, p []byte, err error) {
	messageType, p, err = conn.conn.ReadMessage()

	return messageType, p, err
}

// WriteMessage sends data as a single message of the given messageType on the
// wrapped WebSocket connection.
// The write mutex is held for the whole call, so writes issued through this
// wrapper from different goroutines never overlap.
// It returns the error reported by the underlying connection, if any.
func (conn *WSConnLock) WriteMessage(messageType int, data []byte) error {
	conn.writeMu.Lock()
	defer conn.writeMu.Unlock()

	ws := conn.conn

	return ws.WriteMessage(messageType, data)
}

// Close closes the wrapped WebSocket connection and returns the error from
// the underlying Close call.
// The write mutex is acquired first, so Close waits for any write that is
// currently in progress through this wrapper and blocks new ones until the
// connection has been closed.
// NOTE(review): because of the lock, Close cannot interrupt a write that is
// stuck on a stalled peer; gorilla/websocket documents Close as callable
// concurrently with other methods — confirm whether the lock is needed here.
func (conn *WSConnLock) Close() error {
	conn.writeMu.Lock()
	defer conn.writeMu.Unlock()

	ws := conn.conn

	return ws.Close()
}

// NewMultiplexer creates a new Multiplexer instance.
func NewMultiplexer(kubeConfigStore kubeconfig.ContextStore) *Multiplexer {
return &Multiplexer{
Expand Down Expand Up @@ -183,7 +245,7 @@ func (m *Multiplexer) establishClusterConnection(
userID,
path,
query string,
clientConn *websocket.Conn,
clientConn *WSConnLock,
) (*Connection, error) {
config, err := m.getClusterConfigWithFallback(clusterID, userID)
if err != nil {
Expand Down Expand Up @@ -246,7 +308,7 @@ func (m *Multiplexer) createConnection(
userID,
path,
query string,
clientConn *websocket.Conn,
clientConn *WSConnLock,
) *Connection {
return &Connection{
ClusterID: clusterID,
Expand Down Expand Up @@ -355,6 +417,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque

defer clientConn.Close()

lockClientConn := NewWSConnLock(clientConn)

for {
msg, err := m.readClientMessage(clientConn)
if err != nil {
Expand All @@ -368,9 +432,9 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque
continue
}

conn, err := m.getOrCreateConnection(msg, clientConn)
conn, err := m.getOrCreateConnection(msg, lockClientConn)
if err != nil {
m.handleConnectionError(clientConn, msg, err)
m.handleConnectionError(lockClientConn, msg, err)

continue
}
Expand Down Expand Up @@ -408,7 +472,7 @@ func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, er
}

// getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist.
func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) {
func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *WSConnLock) (*Connection, error) {
connKey := m.createConnectionKey(msg.ClusterID, msg.Path, msg.UserID)

m.mutex.RLock()
Expand Down Expand Up @@ -437,7 +501,7 @@ func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.C
}

// handleConnectionError handles errors that occur when establishing a connection.
func (m *Multiplexer) handleConnectionError(clientConn *websocket.Conn, msg Message, err error) {
func (m *Multiplexer) handleConnectionError(clientConn *WSConnLock, msg Message, err error) {
errorMsg := struct {
ClusterID string `json:"clusterId"`
Error string `json:"error"`
Expand Down Expand Up @@ -477,7 +541,7 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error
}

// handleClusterMessages handles messages from a cluster connection.
func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) {
func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *WSConnLock) {
defer m.cleanupConnection(conn)

var lastResourceVersion string
Expand All @@ -497,7 +561,7 @@ func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websoc
// processClusterMessage processes a single message from the cluster.
func (m *Multiplexer) processClusterMessage(
conn *Connection,
clientConn *websocket.Conn,
clientConn *WSConnLock,
lastResourceVersion *string,
) error {
messageType, message, err := conn.WSConn.ReadMessage()
Expand Down Expand Up @@ -541,7 +605,7 @@ func (m *Multiplexer) processClusterMessage(
func (m *Multiplexer) sendIfNewResourceVersion(
message []byte,
conn *Connection,
clientConn *websocket.Conn,
clientConn *WSConnLock,
lastResourceVersion *string,
) error {
var obj map[string]interface{}
Expand Down Expand Up @@ -581,7 +645,7 @@ func (m *Multiplexer) sendIfNewResourceVersion(
}

// sendCompleteMessage sends a COMPLETE message to the client.
func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error {
func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *WSConnLock) error {
conn.mu.RLock()
if conn.closed {
conn.mu.RUnlock()
Expand Down Expand Up @@ -614,7 +678,7 @@ func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocke
// sendDataMessage sends the actual data message to the client.
func (m *Multiplexer) sendDataMessage(
conn *Connection,
clientConn *websocket.Conn,
clientConn *WSConnLock,
messageType int,
message []byte,
) error {
Expand Down
Loading

0 comments on commit d290009

Please sign in to comment.