diff --git a/cmd/agent/app/server.go b/cmd/agent/app/server.go index 8cc3cfd0e..8c087de53 100644 --- a/cmd/agent/app/server.go +++ b/cmd/agent/app/server.go @@ -24,10 +24,13 @@ import ( "net" "net/http" "net/http/pprof" + "os" + "os/signal" "runtime" runpprof "runtime/pprof" "strconv" "strings" + "syscall" "time" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -49,8 +52,8 @@ func NewAgentCommand(a *Agent, o *options.GrpcProxyAgentOptions) *cobra.Command Use: "agent", Long: `A gRPC agent, Connects to the proxy and then allows traffic to be forwarded to it.`, RunE: func(cmd *cobra.Command, args []string) error { - stopCh := make(chan struct{}) - return a.Run(o, stopCh) + drainCh, stopCh := SetupSignalHandler() + return a.Run(o, drainCh, stopCh) }, } @@ -64,13 +67,13 @@ type Agent struct { cs *agent.ClientSet } -func (a *Agent) Run(o *options.GrpcProxyAgentOptions, stopCh <-chan struct{}) error { +func (a *Agent) Run(o *options.GrpcProxyAgentOptions, drainCh, stopCh <-chan struct{}) error { o.Print() if err := o.Validate(); err != nil { return fmt.Errorf("failed to validate agent options with %v", err) } - cs, err := a.runProxyConnection(o, stopCh) + cs, err := a.runProxyConnection(o, drainCh, stopCh) if err != nil { return fmt.Errorf("failed to run proxy connection with %v", err) } @@ -92,7 +95,31 @@ func (a *Agent) Run(o *options.GrpcProxyAgentOptions, stopCh <-chan struct{}) er return nil } -func (a *Agent) runProxyConnection(o *options.GrpcProxyAgentOptions, stopCh <-chan struct{}) (*agent.ClientSet, error) { +var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} + +func SetupSignalHandler() (drainCh, stopCh <-chan struct{}) { + drain := make(chan struct{}) + stop := make(chan struct{}) + c := make(chan os.Signal, 2) + signal.Notify(c, shutdownSignals...) + labels := runpprof.Labels( + "core", "signalHandler", + ) + go runpprof.Do(context.Background(), labels, func(context.Context) { handleSignals(c, drain, stop) }) + + return drain, stop +} + +func handleSignals(signalCh chan os.Signal, drainCh, stopCh chan struct{}) { + s := <-signalCh + klog.V(2).InfoS("Received first signal", "signal", s) + close(drainCh) + s = <-signalCh + klog.V(2).InfoS("Received second signal", "signal", s) + close(stopCh) +} + +func (a *Agent) runProxyConnection(o *options.GrpcProxyAgentOptions, drainCh, stopCh <-chan struct{}) (*agent.ClientSet, error) { var tlsConfig *tls.Config var err error if tlsConfig, err = util.GetClientTLSConfig(o.CaCert, o.AgentCert, o.AgentKey, o.ProxyServerHost, o.AlpnProtos); err != nil { @@ -106,7 +133,7 @@ func (a *Agent) runProxyConnection(o *options.GrpcProxyAgentOptions, stopCh <-ch }), } cc := o.ClientSetConfig(dialOptions...) - cs := cc.NewAgentClientSet(stopCh) + cs := cc.NewAgentClientSet(drainCh, stopCh) cs.Serve() return cs, nil diff --git a/pkg/agent/client.go b/pkg/agent/client.go index 16b1df244..f843d1538 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -137,7 +137,11 @@ type Client struct { address string opts []grpc.DialOption conn *grpc.ClientConn - stopCh chan struct{} + + drainCh <-chan struct{} + drainOnce sync.Once + stopCh chan struct{} + // locks sendLock sync.Mutex recvLock sync.Mutex @@ -158,6 +162,7 @@ func newAgentClient(address, agentID, agentIdentifiers string, cs *ClientSet, op agentIdentifiers: agentIdentifiers, opts: opts, probeInterval: cs.probeInterval, + drainCh: cs.drainCh, stopCh: make(chan struct{}), serviceAccountTokenPath: cs.serviceAccountTokenPath, connManager: newConnectionManager(), @@ -325,6 +330,19 @@ func (a *Client) Serve() { case <-a.stopCh: klog.V(2).InfoS("stop agent client.") return + case <-a.drainCh: + a.drainOnce.Do(func() { + klog.V(2).InfoS("drain agent client.") + drainPkt := &client.Packet{ + Type: client.PacketType_DRAIN, + Payload: &client.Packet_Drain{ + Drain: &client.Drain{}, + }, + } + if err := a.Send(drainPkt); err != nil { + klog.ErrorS(err, "drain failure", "") + } + }) default: } diff --git a/pkg/agent/client_test.go b/pkg/agent/client_test.go index 87b61b1b9..6e69ebbcd 100644 --- a/pkg/agent/client_test.go +++ b/pkg/agent/client_test.go @@ -343,6 +343,40 @@ func TestFailedSend_DialResp_GRPC(t *testing.T) { }() } +func TestDrain(t *testing.T) { + var stream agent.AgentService_ConnectClient + drainCh := make(chan struct{}) + stopCh := make(chan struct{}) + cs := &ClientSet{ + clients: make(map[string]*Client), + drainCh: drainCh, + stopCh: stopCh, + } + testClient := &Client{ + connManager: newConnectionManager(), + drainCh: drainCh, + stopCh: stopCh, + cs: cs, + } + testClient.stream, stream = pipe() + + // Start agent + go testClient.Serve() + defer close(stopCh) + + // Simulate pod first shutdown signal + close(drainCh) + + // Expect to receive DRAIN packet from (Agent) Client + pkt, err := stream.Recv() + if err != nil { + t.Fatal(err) + } + if pkt.Type != client.PacketType_DRAIN { + t.Errorf("expect PacketType_DRAIN; got %v", pkt.Type) + } +} + // fakeStream implements AgentService_ConnectClient type fakeStream struct { grpc.ClientStream diff --git a/pkg/agent/clientset.go b/pkg/agent/clientset.go index c5adcf4f7..6a2510a84 100644 --- a/pkg/agent/clientset.go +++ b/pkg/agent/clientset.go @@ -52,6 +52,8 @@ type ClientSet struct { dialOptions []grpc.DialOption // file path contains service account token serviceAccountTokenPath string + // channel to signal that the agent is pending termination. + drainCh <-chan struct{} // channel to signal shutting down the client set. Primarily for test. stopCh <-chan struct{} @@ -141,7 +143,7 @@ type ClientSetConfig struct { SyncForever bool } -func (cc *ClientSetConfig) NewAgentClientSet(stopCh <-chan struct{}) *ClientSet { +func (cc *ClientSetConfig) NewAgentClientSet(drainCh, stopCh <-chan struct{}) *ClientSet { return &ClientSet{ clients: make(map[string]*Client), agentID: cc.AgentID, @@ -154,6 +156,7 @@ func (cc *ClientSetConfig) NewAgentClientSet(stopCh <-chan struct{}) *ClientSet serviceAccountTokenPath: cc.ServiceAccountTokenPath, warnOnChannelLimit: cc.WarnOnChannelLimit, syncForever: cc.SyncForever, + drainCh: drainCh, stopCh: stopCh, } } diff --git a/pkg/server/server.go b/pkg/server/server.go index d65d5c4be..b3f883c93 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -990,6 +990,8 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, agentID string, recvCh < klog.V(5).InfoS("CLOSE_RSP sent to frontend", "connectionID", resp.ConnectID) } + case client.PacketType_DRAIN: + klog.V(2).InfoS("agent is draining", "agentID", agentID) default: klog.V(5).InfoS("Ignoring unrecognized packet from backend", "packet", pkt, "agentID", agentID) } diff --git a/tests/framework/agent.go b/tests/framework/agent.go index 15b9ae2e2..6d446fca6 100644 --- a/tests/framework/agent.go +++ b/tests/framework/agent.go @@ -27,6 +27,7 @@ import ( "path/filepath" "strconv" "sync" + "syscall" "testing" "time" @@ -53,6 +54,7 @@ type AgentRunner interface { type Agent interface { GetConnectedServerCount() (int, error) Ready() bool + Drain() Stop() Metrics() metricstest.AgentTester } @@ -66,9 +68,10 @@ func (*InProcessAgentRunner) Start(t testing.TB, opts AgentOpts) (Agent, error) } ctx, cancel := context.WithCancel(context.Background()) + drainCh := make(chan struct{}) stopCh := make(chan struct{}) go func() { - if err := a.Run(o, stopCh); err != nil { + if err := a.Run(o, drainCh, stopCh); err != nil { log.Printf("ERROR running agent: %v", err) cancel() } @@ -84,6 +87,7 @@ func (*InProcessAgentRunner) Start(t testing.TB, opts AgentOpts) (Agent, error) pa := &inProcessAgent{ client: a.ClientSet(), + drainCh: drainCh, stopCh: stopCh, healthAddr: healthAddr, } @@ -94,12 +98,21 @@ func (*InProcessAgentRunner) Start(t testing.TB, opts AgentOpts) (Agent, error) type inProcessAgent struct { client *agent.ClientSet + drainOnce sync.Once + drainCh chan struct{} + stopOnce sync.Once stopCh chan struct{} healthAddr string } +func (a *inProcessAgent) Drain() { + a.drainOnce.Do(func() { + close(a.drainCh) + }) +} + func (a *inProcessAgent) Stop() { a.stopOnce.Do(func() { close(a.stopCh) @@ -160,7 +173,16 @@ type externalAgent struct { cmd *exec.Cmd metrics *metricstest.Tester - stopOnce sync.Once + drainOnce sync.Once + stopOnce sync.Once +} + +func (a *externalAgent) Drain() { + a.drainOnce.Do(func() { + if err := a.cmd.Process.Signal(syscall.SIGTERM); err != nil { + log.Fatalf("Error draining agent process: %v", err) + } + }) } func (a *externalAgent) Stop() {