diff --git a/internal/client/transport/tcp.go b/internal/client/transport/tcp.go index 2eed609..c131b9e 100644 --- a/internal/client/transport/tcp.go +++ b/internal/client/transport/tcp.go @@ -102,7 +102,6 @@ connectLoop: for { select { case <-c.ctx.Done(): - go c.closeControlChannel("context cancellation") return default: tunnelTCPConn, err := c.tcpDialer(c.config.RemoteAddr, c.config.Nodelay) diff --git a/internal/client/transport/ws.go b/internal/client/transport/ws.go index 9a6744d..2be901e 100644 --- a/internal/client/transport/ws.go +++ b/internal/client/transport/ws.go @@ -78,6 +78,8 @@ func (c *WsTransport) Restart() { c.cancel() } + go c.closeControlChannel("restarting client") + time.Sleep(2 * time.Second) ctx, cancel := context.WithCancel(c.parentctx) @@ -93,6 +95,14 @@ func (c *WsTransport) Restart() { } +func (c *WsTransport) closeControlChannel(reason string) { + if c.controlChannel != nil { + _ = c.controlChannel.WriteMessage(websocket.TextMessage, []byte("closed")) + c.controlChannel.Close() + c.logger.Debugf("control channel closed due to %s", reason) + } +} + func (c *WsTransport) ChannelDialer() { // for webui if c.config.WebPort > 0 { @@ -101,6 +111,7 @@ func (c *WsTransport) ChannelDialer() { c.config.TunnelStatus = "Disconnected (Websocket)" +connectLoop: for { select { case <-c.ctx.Done(): @@ -121,9 +132,12 @@ func (c *WsTransport) ChannelDialer() { go c.channelListener() - return + break connectLoop } } + + <-c.ctx.Done() + go c.closeControlChannel("context cancellation") } func (c *WsTransport) channelListener() { diff --git a/internal/server/transport/tcp.go b/internal/server/transport/tcp.go index 5950d35..9e47e3c 100644 --- a/internal/server/transport/tcp.go +++ b/internal/server/transport/tcp.go @@ -61,7 +61,7 @@ func NewTCPServer(parentCtx context.Context, config *TcpConfig, logger *logrus.L tunnelChannel: make(chan net.Conn, config.ChannelSize), getNewConnChan: make(chan struct{}, config.ChannelSize), controlChannel: nil, // will be set when a control connection is established - timeout: 3 * time.Second, // Default timeout for waiting for a tunnel connection + timeout: 5 * time.Second, // Default timeout for waiting for a tunnel connection heartbeatDuration: time.Duration(config.Heartbeat) * time.Second, // Heartbeat duration heartbeatSig: "0", // Default heartbeat signal chanSignal: "1", // Default channel signal @@ -498,32 +498,12 @@ func (s *TcpTransport) handleTCPSession(remotePort int, acceptChan chan net.Conn return case <-s.ctx.Done(): - for { - select { - case conn := <-acceptChan: - if conn != nil { - conn.Close() - s.logger.Trace("existing local connections have been closed.") - } - default: - return - } - } + return } } case <-s.ctx.Done(): - for { - select { - case conn := <-acceptChan: - if conn != nil { - conn.Close() - s.logger.Trace("existing local connections have been closed.") - } - default: - return - } - } + return } diff --git a/internal/server/transport/ws.go b/internal/server/transport/ws.go index d4f29b5..57a59c4 100644 --- a/internal/server/transport/ws.go +++ b/internal/server/transport/ws.go @@ -95,6 +95,9 @@ func (s *WsTransport) Restart() { s.cancel() } + // Close any open connections in the tunnel channel. + go s.cleanupConnections() + time.Sleep(2 * time.Second) ctx, cancel := context.WithCancel(s.parentctx) @@ -111,6 +114,61 @@ func (s *WsTransport) Restart() { go s.TunnelListener() } + +// cleanupConnections closes all active connections in the tunnel channel. +func (s *WsTransport) cleanupConnections() { + if s.controlChannel != nil { + s.logger.Debug("control channel have been closed.") + s.controlChannel.Close() + } + for { + select { + case conn := <-s.tunnelChannel: + if conn.conn != nil { + conn.conn.Close() + s.logger.Trace("existing tunnel connections have been closed.") + } + default: + return + } + } +} + +func (s *WsTransport) getClosedSignal() { + for { + // Channel to receive the message or error + resultChan := make(chan struct { + message []byte + err error + }) + go func() { + _, message, err := s.controlChannel.ReadMessage() + resultChan <- struct { + message []byte + err error + }{message, err} + }() + + select { + case <-s.ctx.Done(): + return + + case result := <-resultChan: + if result.err != nil { + s.logger.Errorf("failed to receive message from tunnel connection: %v", result.err) + go s.Restart() + return + } + if string(result.message) == "closed" { + s.logger.Info("control channel has been closed by the client") + go s.Restart() + return + } + } + } + +} + func (s *WsTransport) portConfigReader() { // port mapping for listening on each local port for _, portMapping := range s.config.Ports { @@ -260,6 +318,7 @@ func (s *WsTransport) TunnelListener() { go s.heartbeat() go s.poolChecker() go s.portConfigReader() + go s.getClosedSignal() s.config.TunnelStatus = "Connected (Websocket)"