Skip to content

Commit

Permalink
[fix] ws and tcp channel stability
Browse files Browse the repository at this point in the history
  • Loading branch information
Musixal committed Sep 17, 2024
1 parent 2027384 commit 4045b31
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 25 deletions.
1 change: 0 additions & 1 deletion internal/client/transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ connectLoop:
for {
select {
case <-c.ctx.Done():
go c.closeControlChannel("context cancellation")
return
default:
tunnelTCPConn, err := c.tcpDialer(c.config.RemoteAddr, c.config.Nodelay)
Expand Down
16 changes: 15 additions & 1 deletion internal/client/transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func (c *WsTransport) Restart() {
c.cancel()
}

go c.closeControlChannel("restarting client")

time.Sleep(2 * time.Second)

ctx, cancel := context.WithCancel(c.parentctx)
Expand All @@ -93,6 +95,14 @@ func (c *WsTransport) Restart() {

}

func (c *WsTransport) closeControlChannel(reason string) {
if c.controlChannel != nil {
_ = c.controlChannel.WriteMessage(websocket.TextMessage, []byte("closed"))
c.controlChannel.Close()
c.logger.Debugf("control channel closed due to %s", reason)
}
}

func (c *WsTransport) ChannelDialer() {
// for webui
if c.config.WebPort > 0 {
Expand All @@ -101,6 +111,7 @@ func (c *WsTransport) ChannelDialer() {

c.config.TunnelStatus = "Disconnected (Websocket)"

connectLoop:
for {
select {
case <-c.ctx.Done():
Expand All @@ -121,9 +132,12 @@ func (c *WsTransport) ChannelDialer() {

go c.channelListener()

return
break connectLoop
}
}

<-c.ctx.Done()
go c.closeControlChannel("context cancellation")
}

func (c *WsTransport) channelListener() {
Expand Down
26 changes: 3 additions & 23 deletions internal/server/transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewTCPServer(parentCtx context.Context, config *TcpConfig, logger *logrus.L
tunnelChannel: make(chan net.Conn, config.ChannelSize),
getNewConnChan: make(chan struct{}, config.ChannelSize),
controlChannel: nil, // will be set when a control connection is established
timeout: 3 * time.Second, // Default timeout for waiting for a tunnel connection
timeout: 5 * time.Second, // Default timeout for waiting for a tunnel connection
heartbeatDuration: time.Duration(config.Heartbeat) * time.Second, // Heartbeat duration
heartbeatSig: "0", // Default heartbeat signal
chanSignal: "1", // Default channel signal
Expand Down Expand Up @@ -498,32 +498,12 @@ func (s *TcpTransport) handleTCPSession(remotePort int, acceptChan chan net.Conn
return

case <-s.ctx.Done():
for {
select {
case conn := <-acceptChan:
if conn != nil {
conn.Close()
s.logger.Trace("existing local connections have been closed.")
}
default:
return
}
}
return

}
}
case <-s.ctx.Done():
for {
select {
case conn := <-acceptChan:
if conn != nil {
conn.Close()
s.logger.Trace("existing local connections have been closed.")
}
default:
return
}
}
return

}

Expand Down
59 changes: 59 additions & 0 deletions internal/server/transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ func (s *WsTransport) Restart() {
s.cancel()
}

// Close any open connections in the tunnel channel.
go s.cleanupConnections()

time.Sleep(2 * time.Second)

ctx, cancel := context.WithCancel(s.parentctx)
Expand All @@ -111,6 +114,61 @@ func (s *WsTransport) Restart() {
go s.TunnelListener()

}

// cleanupConnections closes all active connections in the tunnel channel.
func (s *WsTransport) cleanupConnections() {
if s.controlChannel != nil {
s.logger.Debug("control channel have been closed.")
s.controlChannel.Close()
}
for {
select {
case conn := <-s.tunnelChannel:
if conn.conn != nil {
conn.conn.Close()
s.logger.Trace("existing tunnel connections have been closed.")
}
default:
return
}
}
}

func (s *WsTransport) getClosedSignal() {
for {
// Channel to receive the message or error
resultChan := make(chan struct {
message []byte
err error
})
go func() {
_, message, err := s.controlChannel.ReadMessage()
resultChan <- struct {
message []byte
err error
}{message, err}
}()

select {
case <-s.ctx.Done():
return

case result := <-resultChan:
if result.err != nil {
s.logger.Errorf("failed to receive message from tunnel connection: %v", result.err)
go s.Restart()
return
}
if string(result.message) == "closed" {
s.logger.Info("control channel has been closed by the client")
go s.Restart()
return
}
}
}

}

func (s *WsTransport) portConfigReader() {
// port mapping for listening on each local port
for _, portMapping := range s.config.Ports {
Expand Down Expand Up @@ -260,6 +318,7 @@ func (s *WsTransport) TunnelListener() {
go s.heartbeat()
go s.poolChecker()
go s.portConfigReader()
go s.getClosedSignal()

s.config.TunnelStatus = "Connected (Websocket)"

Expand Down

0 comments on commit 4045b31

Please sign in to comment.