diff --git a/konnectivity-client/pkg/client/client.go b/konnectivity-client/pkg/client/client.go index 3d7033b31..e181a4344 100644 --- a/konnectivity-client/pkg/client/client.go +++ b/konnectivity-client/pkg/client/client.go @@ -334,11 +334,23 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) { case client.PacketType_DATA: resp := pkt.GetData() + if resp.ConnectID == 0 { + klog.ErrorS(nil, "Received packet missing ConnectID", "packetType", "DATA") + continue + } // TODO: flow control conn, ok := t.conns.get(resp.ConnectID) if !ok { - klog.V(1).InfoS("Connection not recognized", "connectionID", resp.ConnectID) + klog.ErrorS(nil, "Connection not recognized", "connectionID", resp.ConnectID, "packetType", "DATA") + t.Send(&client.Packet{ + Type: client.PacketType_CLOSE_REQ, + Payload: &client.Packet_CloseRequest{ + CloseRequest: &client.CloseRequest{ + ConnectID: resp.ConnectID, + }, + }, + }) continue } timer := time.NewTimer((time.Duration)(t.readTimeoutSeconds) * time.Second) @@ -357,7 +369,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) { conn, ok := t.conns.get(resp.ConnectID) if !ok { - klog.V(1).InfoS("Connection not recognized", "connectionID", resp.ConnectID) + klog.V(1).InfoS("Connection not recognized", "connectionID", resp.ConnectID, "packetType", "CLOSE_RSP") continue } close(conn.readCh) diff --git a/pkg/agent/client.go b/pkg/agent/client.go index fcd58705d..383533e44 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -497,10 +497,26 @@ func (a *Client) Serve() { case client.PacketType_DATA: data := pkt.GetData() klog.V(4).InfoS("received DATA", "connectionID", data.ConnectID) + if data.ConnectID == 0 { + klog.ErrorS(nil, "Received packet missing ConnectID from frontend", "packetType", "DATA") + continue + } ctx, ok := a.connManager.Get(data.ConnectID) if ok { ctx.send(data.Data) + } else { + klog.V(2).InfoS("received DATA for unrecognized connection", "connectionID", data.ConnectID) + a.Send(&client.Packet{ + Type: client.PacketType_CLOSE_RSP, + Payload: &client.Packet_CloseResponse{ + CloseResponse: &client.CloseResponse{ + ConnectID: data.ConnectID, + Error: "unrecognized connectID", + }, + }, + }) + continue } case client.PacketType_CLOSE_REQ: @@ -590,9 +606,14 @@ func (a *Client) proxyToRemote(connID int64, ctx *connContext) { // As the read side of the dataCh channel, we cannot close it. // However serve() may be blocked writing to the channel, // so we need to consume the channel until it is closed. + discardedPktCount := 0 for range ctx.dataCh { // Ignore values as this indicates there was a problem // with the remote connection. + discardedPktCount++ + } + if discardedPktCount > 0 { + klog.V(2).InfoS("Discard packets while exiting proxyToRemote", "pktCount", discardedPktCount, "connectionID", connID) } }() diff --git a/pkg/agent/client_test.go b/pkg/agent/client_test.go index bc4b7b5e6..0528d06a8 100644 --- a/pkg/agent/client_test.go +++ b/pkg/agent/client_test.go @@ -59,7 +59,7 @@ func TestServeData_HTTP(t *testing.T) { })) defer ts.Close() - // Stimulate sending KAS DIAL_REQ to (Agent) Client + // Simulate sending KAS DIAL_REQ to (Agent) Client dialPacket := newDialPacket("tcp", ts.URL[len("http://"):], 111) err = stream.Send(dialPacket) if err != nil { @@ -67,17 +67,17 @@ func TestServeData_HTTP(t *testing.T) { } // Expect receiving DIAL_RSP packet from (Agent) Client - pkg, err := stream.Recv() + pkt, err := stream.Recv() if err != nil { t.Fatal(err.Error()) } - if pkg == nil { + if pkt == nil { t.Fatal("unexpected nil packet") } - if pkg.Type != client.PacketType_DIAL_RSP { - t.Errorf("expect PacketType_DIAL_RSP; got %v", pkg.Type) + if pkt.Type != client.PacketType_DIAL_RSP { + t.Errorf("expect PacketType_DIAL_RSP; got %v", pkt.Type) } - dialRsp := pkg.Payload.(*client.Packet_DialResponse) + dialRsp := pkt.Payload.(*client.Packet_DialResponse) connID := dialRsp.DialResponse.ConnectID if dialRsp.DialResponse.Random != 111 { t.Errorf("expect random=111; got %v", dialRsp.DialResponse.Random) @@ -91,14 +91,14 @@ func TestServeData_HTTP(t *testing.T) { } // Expect receiving http response via (Agent) Client - pkg, _ = stream.Recv() - if pkg == nil { + pkt, _ = stream.Recv() + if pkt == nil { t.Fatal("unexpected nil packet") } - if pkg.Type != client.PacketType_DATA { - t.Errorf("expect PacketType_DATA; got %v", pkg.Type) + if pkt.Type != client.PacketType_DATA { + t.Errorf("expect PacketType_DATA; got %v", pkt.Type) } - data := pkg.Payload.(*client.Packet_Data).Data.Data + data := pkt.Payload.(*client.Packet_Data).Data.Data // Verify response data // @@ -117,14 +117,14 @@ func TestServeData_HTTP(t *testing.T) { ts.Close() // Verify receiving CLOSE_RSP - pkg, _ = stream.Recv() - if pkg == nil { + pkt, _ = stream.Recv() + if pkt == nil { t.Fatal("unexpected nil packet") } - if pkg.Type != client.PacketType_CLOSE_RSP { - t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkg.Type) + if pkt.Type != client.PacketType_CLOSE_RSP { + t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkt.Type) } - closeErr := pkg.Payload.(*client.Packet_CloseResponse).CloseResponse.Error + closeErr := pkt.Payload.(*client.Packet_CloseResponse).CloseResponse.Error if closeErr != "" { t.Errorf("expect nil closeErr; got %v", closeErr) } @@ -159,7 +159,7 @@ func TestClose_Client(t *testing.T) { })) defer ts.Close() - // Stimulate sending KAS DIAL_REQ to (Agent) Client + // Simulate sending KAS DIAL_REQ to (Agent) Client dialPacket := newDialPacket("tcp", ts.URL[len("http://"):], 111) err := stream.Send(dialPacket) if err != nil { @@ -167,14 +167,14 @@ func TestClose_Client(t *testing.T) { } // Expect receiving DIAL_RSP packet from (Agent) Client - pkg, _ := stream.Recv() - if pkg == nil { + pkt, _ := stream.Recv() + if pkt == nil { t.Fatal("unexpected nil packet") } - if pkg.Type != client.PacketType_DIAL_RSP { - t.Errorf("expect PacketType_DIAL_RSP; got %v", pkg.Type) + if pkt.Type != client.PacketType_DIAL_RSP { + t.Errorf("expect PacketType_DIAL_RSP; got %v", pkt.Type) } - dialRsp := pkg.Payload.(*client.Packet_DialResponse) + dialRsp := pkt.Payload.(*client.Packet_DialResponse) connID := dialRsp.DialResponse.ConnectID if dialRsp.DialResponse.Random != 111 { t.Errorf("expect random=111; got %v", dialRsp.DialResponse.Random) @@ -186,14 +186,14 @@ func TestClose_Client(t *testing.T) { } // Expect receiving close response via (Agent) Client - pkg, _ = stream.Recv() - if pkg == nil { + pkt, _ = stream.Recv() + if pkt == nil { t.Error("unexpected nil packet") } - if pkg.Type != client.PacketType_CLOSE_RSP { - t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkg.Type) + if pkt.Type != client.PacketType_CLOSE_RSP { + t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkt.Type) } - closeErr := pkg.Payload.(*client.Packet_CloseResponse).CloseResponse.Error + closeErr := pkt.Payload.(*client.Packet_CloseResponse).CloseResponse.Error if closeErr != "" { t.Errorf("expect nil closeErr; got %v", closeErr) } @@ -209,20 +209,59 @@ func TestClose_Client(t *testing.T) { } // Expect receiving close response via (Agent) Client - pkg, _ = stream.Recv() - if pkg == nil { + pkt, _ = stream.Recv() + if pkt == nil { t.Error("unexpected nil packet") } - if pkg.Type != client.PacketType_CLOSE_RSP { - t.Errorf("expect PacketType_CLOSE_RSP; got %+v", pkg) + if pkt.Type != client.PacketType_CLOSE_RSP { + t.Errorf("expect PacketType_CLOSE_RSP; got %+v", pkt) } - closeErr = pkg.Payload.(*client.Packet_CloseResponse).CloseResponse.Error + closeErr = pkt.Payload.(*client.Packet_CloseResponse).CloseResponse.Error if closeErr != "Unknown connectID" { t.Errorf("expect Unknown connectID; got %v", closeErr) } } +func TestConnectionMismatch(t *testing.T) { + var stream agent.AgentService_ConnectClient + stopCh := make(chan struct{}) + cs := &ClientSet{ + clients: make(map[string]*Client), + stopCh: stopCh, + } + testClient := &Client{ + connManager: newConnectionManager(), + stopCh: stopCh, + cs: cs, + } + testClient.stream, stream = pipe() + + // Start agent + go testClient.Serve() + defer close(stopCh) + + // Simulate sending a DATA packet to (Agent) Client + const connID = 12345 + pkt := newDataPacket(connID, []byte("hello world")) + if err := stream.Send(pkt); err != nil { + t.Fatal(err) + } + + // Expect to receive CLOSE_RSP packet from (Agent) Client + pkt, err := stream.Recv() + if err != nil { + t.Fatal(err) + } + if pkt.Type != client.PacketType_CLOSE_RSP { + t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkt.Type) + } + closeRsp := pkt.Payload.(*client.Packet_CloseResponse).CloseResponse + if closeRsp.ConnectID != connID { + t.Errorf("expect connID=%d; got %v", connID, closeRsp.ConnectID) + } +} + // brokenStream wraps a ConnectClient and returns an error on Send and/or Recv if the respective // error is non-nil. type brokenStream struct { @@ -287,7 +326,7 @@ func TestFailedSend_DialResp_GRPC(t *testing.T) { time.Sleep(time.Second) defer goleakVerifyNone(t, goleak.IgnoreCurrent()) - // Stimulate sending KAS DIAL_REQ to (Agent) Client + // Simulate sending KAS DIAL_REQ to (Agent) Client dialPacket := newDialPacket("tcp", strings.TrimPrefix(ts.URL, "http://"), 111) err := stream.Send(dialPacket) if err != nil { @@ -329,9 +368,9 @@ func (s *fakeStream) Send(packet *client.Packet) error { func (s *fakeStream) Recv() (*client.Packet, error) { select { - case pkg := <-s.r: - klog.V(4).InfoS("[DEBUG] recv", "packet", pkg) - return pkg, nil + case pkt := <-s.r: + klog.V(4).InfoS("[DEBUG] recv", "packet", pkt) + return pkt, nil case <-time.After(5 * time.Second): return nil, errors.New("timeout recv") } diff --git a/pkg/server/server.go b/pkg/server/server.go index 44a9d1a2a..36dfa6d70 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -439,6 +439,21 @@ func (s *ProxyServer) serveRecvFrontend(stream client.ProxyService_ProxyServer, var backend Backend var err error + defer func() { + // As the read side of the recvCh channel, we cannot close it. + // However readFrontendToChannel() may be blocked writing to the channel, + // so we need to consume the channel until it is closed. + discardedPktCount := 0 + for range recvCh { + // Ignore values as this indicates there was a problem + // with the remote connection. + discardedPktCount++ + } + if discardedPktCount > 0 { + klog.V(2).InfoS("Discard packets while exiting serveRecvFrontend", "pktCount", discardedPktCount, "connectionID", firstConnID) + } + }() + for pkt := range recvCh { switch pkt.Type { case client.PacketType_DIAL_REQ: @@ -492,16 +507,19 @@ func (s *ProxyServer) serveRecvFrontend(stream client.ProxyService_ProxyServer, connID := pkt.GetCloseRequest().ConnectID klog.V(5).InfoS("Received CLOSE_REQ", "connectionID", connID) if backend == nil { - klog.V(2).InfoS("Backend has not been initialized for requested connection. Client should send a Dial Request first", - "connectionID", connID) + klog.V(2).InfoS("Backend has not been initialized for this connection", "connectionID", connID) + s.sendFrontendClose(stream, connID, "backend uninitialized") continue } if err := backend.Send(pkt); err != nil { // TODO: retry with other backends connecting to this agent. klog.ErrorS(err, "CLOSE_REQ to Backend failed", "connectionID", connID) + s.sendFrontendClose(stream, connID, "CLOSE_REQ to backend failed") } else { klog.V(5).InfoS("CLOSE_REQ sent to backend", "connectionID", connID) } + klog.V(3).InfoS("Closing frontend streaming per CLOSE_REQ", "connectionID", connID) + return case client.PacketType_DIAL_CLS: random := pkt.GetCloseDial().Random @@ -520,16 +538,29 @@ func (s *ProxyServer) serveRecvFrontend(stream client.ProxyService_ProxyServer, connID := pkt.GetData().ConnectID data := pkt.GetData().Data klog.V(5).InfoS("Received data from connection", "bytes", len(data), "connectionID", connID) + if backend == nil { + klog.V(2).InfoS("Backend has not been initialized for this connection", "connectionID", connID) + s.sendFrontendClose(stream, connID, "backend not initialized") + return + } + + if connID == 0 { + klog.ErrorS(nil, "Received packet missing ConnectID from frontend", "packetType", "DATA") + continue + } + if firstConnID == 0 { firstConnID = connID } else if firstConnID != connID { - klog.V(5).InfoS("Data does not match first connection id", "fistConnectionID", firstConnID, "connectionID", connID) + klog.ErrorS(nil, "Data does not match first connection id", "firstConnectionID", firstConnID, "connectionID", connID) + // Something went very wrong if we get here. Close both connections to avoid leaks. + s.sendBackendClose(backend, connID, 0, "mismatched connection IDs") + s.sendBackendClose(backend, firstConnID, 0, "mismatched connection IDs") + s.sendFrontendClose(stream, connID, "mismatched connection IDs") + s.sendFrontendClose(stream, firstConnID, "mismatched connection IDs") + return } - if backend == nil { - klog.V(2).InfoS("Backend has not been initialized for the connection. Client should send a Dial Request first", "connectionID", connID) - continue - } if err := backend.Send(pkt); err != nil { // TODO: retry with other backends connecting to this agent. klog.ErrorS(err, "DATA to Backend failed", "connectionID", connID) @@ -752,6 +783,18 @@ func (s *ProxyServer) readBackendToChannel(stream agent.AgentService_ConnectServ // route the packet back to the correct client func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentService_ConnectServer, agentID string, recvCh <-chan *client.Packet) { + defer func() { + // Drain recvCh to ensure that readBackendToChannel is not blocked on a channel write. + // This should never happen, as termination of this function should only be initiated by closing recvCh. + discardedPktCount := 0 + for range recvCh { + discardedPktCount++ + } + if discardedPktCount > 0 { + klog.V(2).InfoS("Discard packets while exiting serveRecvBackend", "pktCount", discardedPktCount, "agentID", agentID) + } + }() + defer func() { // Close all connected frontends when the agent connection is closed // TODO(#126): Frontends in PendingDial state that have not been added to the @@ -787,7 +830,7 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentServic klog.V(2).InfoS("DIAL_RSP not recognized; dropped", "dialID", resp.Random, "agentID", agentID, "connectionID", resp.ConnectID) metrics.Metrics.ObserveDialFailure(metrics.DialFailureUnrecognizedResponse) if resp.ConnectID != 0 { - s.sendCloseRequest(stream, resp.ConnectID, resp.Random, "failed to notify agent of closing due to unknown dial id") + s.sendBackendClose(stream, resp.ConnectID, resp.Random, "unknown dial id") } } else { dialErr := false @@ -808,7 +851,7 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentServic // If we never finish setting up the tunnel for ConnectID, then the connection is dead. // Currently, the agent will no resend DIAL_RSP, so connection is dead. // We already attempted to tell the frontend that. We should ensure we tell the backend. - s.sendCloseRequest(stream, resp.ConnectID, resp.Random, "failed to notify agent of closing due to dial error") + s.sendBackendClose(stream, resp.ConnectID, resp.Random, "dial error") dialErr = true } // Avoid adding the frontend if there was an error dialing the destination @@ -854,9 +897,15 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentServic case client.PacketType_DATA: resp := pkt.GetData() klog.V(5).InfoS("Received data from agent", "bytes", len(resp.Data), "agentID", agentID, "connectionID", resp.ConnectID) + if resp.ConnectID == 0 { + klog.ErrorS(nil, "Received packet missing ConnectID from agent", "packetType", "DATA") + continue + } + frontend, err := s.getFrontend(agentID, resp.ConnectID) if err != nil { - klog.ErrorS(err, "could not get frontend client", "agentID", agentID, "connectionID", resp.ConnectID) + klog.V(2).InfoS("could not get frontend client; closing connection", "agentID", agentID, "connectionID", resp.ConnectID, "error", err) + s.sendBackendClose(stream, resp.ConnectID, 0, "missing frontend") break } if err := frontend.send(pkt); err != nil { @@ -890,7 +939,7 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentServic klog.V(5).InfoS("Close backend of agent", "backend", stream, "agentID", agentID) } -func (s *ProxyServer) sendCloseRequest(stream agent.AgentService_ConnectServer, connectID int64, random int64, failMsg string) { +func (s *ProxyServer) sendBackendClose(stream Backend, connectID int64, random int64, reason string) { pkt := &client.Packet{ Type: client.PacketType_CLOSE_REQ, Payload: &client.Packet_CloseRequest{ @@ -903,6 +952,24 @@ func (s *ProxyServer) sendCloseRequest(stream agent.AgentService_ConnectServer, metrics.Metrics.ObservePacket(segment, pkt.Type) if err := stream.Send(pkt); err != nil { metrics.Metrics.ObserveStreamError(segment, err, pkt.Type) - klog.V(5).ErrorS(err, failMsg, "dialID", random, "agentID", agentID, "connectionID", connectID) + klog.V(5).ErrorS(err, "Failed to send close to agent", "closeReason", reason, "dialID", random, "agentID", agentID, "connectionID", connectID) + } +} + +func (s *ProxyServer) sendFrontendClose(stream client.ProxyService_ProxyServer, connectID int64, reason string) { + pkt := &client.Packet{ + Type: client.PacketType_CLOSE_RSP, + Payload: &client.Packet_CloseResponse{ + CloseResponse: &client.CloseResponse{ + ConnectID: connectID, + Error: reason, + }, + }, + } + const segment = commonmetrics.SegmentToClient + metrics.Metrics.ObservePacket(segment, pkt.Type) + if err := stream.Send(pkt); err != nil { + metrics.Metrics.ObserveStreamError(segment, err, pkt.Type) + klog.V(5).ErrorS(err, "Failed to send close to frontend", "closeReason", reason, "connectionID", connectID) } } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 1400cd2bd..0f896509d 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -354,16 +354,13 @@ func TestServerProxyNormalClose(t *testing.T) { gomock.InOrder( frontendConn.EXPECT().Recv().Return(dialReq, nil).Times(1), frontendConn.EXPECT().Recv().Return(data, nil).Times(1), - frontendConn.EXPECT().Recv().Return(closePacket(connectID), nil).Times(1), + frontendConn.EXPECT().Recv().Return(closeReqPkt(connectID), nil).Times(1), frontendConn.EXPECT().Recv().Return(nil, io.EOF).Times(1), ) gomock.InOrder( agentConn.EXPECT().Send(dialReq).Return(nil).Times(1), agentConn.EXPECT().Send(data).Return(nil).Times(1), - agentConn.EXPECT().Send(closePacket(connectID)).Return(nil).Times(1), - // This extra close is unwanted and should be removed; see - // https://github.com/kubernetes-sigs/apiserver-network-proxy/pull/307 - agentConn.EXPECT().Send(closePacket(connectID)).Return(nil).Times(1), + agentConn.EXPECT().Send(closeReqPkt(connectID)).Return(nil).Times(1), ) } baseServerProxyTestWithBackend(t, validate) @@ -440,7 +437,10 @@ func TestServerProxyRecvChanFull(t *testing.T) { return data, nil }), - frontendConn.EXPECT().Recv().Return(closePacket(1), nil), + frontendConn.EXPECT().Recv().Return(closeReqPkt(1), nil), + // Ensure that the go-routines don't deadlock if more packets are received before closing the connection. + // This is a bit contrived, but exercises a possible failure scenario. + frontendConn.EXPECT().Recv().Return(data, nil).Times(xfrChannelSize+1), frontendConn.EXPECT().Recv().Return(nil, io.EOF), ) gomock.InOrder( @@ -455,16 +455,83 @@ func TestServerProxyRecvChanFull(t *testing.T) { return nil }), agentConn.EXPECT().Send(data).Return(nil).Times(xfrChannelSize+1), // Expect the remaining packets to be sent. - agentConn.EXPECT().Send(closePacket(1)).Return(nil), - // This extra close is unwanted and should be removed; see - // https://github.com/kubernetes-sigs/apiserver-network-proxy/pull/307 - agentConn.EXPECT().Send(closePacket(1)).Return(nil), + agentConn.EXPECT().Send(closeReqPkt(1)).Return(nil), ) } baseServerProxyTestWithBackend(t, validate) } -func closePacket(connectID int64) *client.Packet { +func TestServerProxyNoDial(t *testing.T) { + baseServerProxyTestWithBackend(t, func(frontendConn, agentConn *agentmock.MockAgentService_ConnectServer) { + const connectID = 123456 + data := &client.Packet{ + Type: client.PacketType_DATA, + Payload: &client.Packet_Data{ + Data: &client.Data{ + ConnectID: connectID, + }, + }, + } + + gomock.InOrder( + frontendConn.EXPECT().Recv().Return(data, nil), + frontendConn.EXPECT().Recv().Return(nil, io.EOF), + ) + frontendConn.EXPECT().Send(closeRspPkt(connectID, "backend not initialized")).Return(nil) + }) +} + +func TestServerProxyConnectionMismatch(t *testing.T) { + baseServerProxyTestWithBackend(t, func(frontendConn, agentConn *agentmock.MockAgentService_ConnectServer) { + const firstConnectID = 123456 + const secondConnectID = 654321 + dialReq := &client.Packet{ + Type: client.PacketType_DIAL_REQ, + Payload: &client.Packet_DialRequest{ + DialRequest: &client.DialRequest{ + Protocol: "tcp", + Address: "127.0.0.1:8080", + Random: 111, + }, + }, + } + data := &client.Packet{ + Type: client.PacketType_DATA, + Payload: &client.Packet_Data{ + Data: &client.Data{ + ConnectID: firstConnectID, + Data: []byte("hello"), + }, + }, + } + mismatchedData := &client.Packet{ + Type: client.PacketType_DATA, + Payload: &client.Packet_Data{ + Data: &client.Data{ + ConnectID: secondConnectID, + Data: []byte("world"), + }, + }, + } + + gomock.InOrder( + frontendConn.EXPECT().Recv().Return(dialReq, nil), + frontendConn.EXPECT().Recv().Return(data, nil), + frontendConn.EXPECT().Recv().Return(mismatchedData, nil), + frontendConn.EXPECT().Recv().Return(nil, io.EOF), + ) + gomock.InOrder( + agentConn.EXPECT().Send(dialReq).Return(nil), + agentConn.EXPECT().Send(data).Return(nil), + ) + agentConn.EXPECT().Send(closeReqPkt(secondConnectID)).Return(nil) + agentConn.EXPECT().Send(closeReqPkt(firstConnectID)).Return(nil) + frontendConn.EXPECT().Send(closeRspPkt(secondConnectID, "mismatched connection IDs")).Return(nil) + frontendConn.EXPECT().Send(closeRspPkt(firstConnectID, "mismatched connection IDs")).Return(nil) + }) +} + +func closeReqPkt(connectID int64) *client.Packet { return &client.Packet{ Type: client.PacketType_CLOSE_REQ, Payload: &client.Packet_CloseRequest{ @@ -473,3 +540,15 @@ func closePacket(connectID int64) *client.Packet { }}, } } + +func closeRspPkt(connectID int64, errMsg string) *client.Packet { + return &client.Packet{ + Type: client.PacketType_CLOSE_RSP, + Payload: &client.Packet_CloseResponse{ + CloseResponse: &client.CloseResponse{ + ConnectID: connectID, + Error: errMsg, + }, + }, + } +} diff --git a/tests/concurrent_test.go b/tests/concurrent_test.go index d6d906c39..d3730f865 100644 --- a/tests/concurrent_test.go +++ b/tests/concurrent_test.go @@ -47,16 +47,17 @@ func TestProxy_ConcurrencyGRPC(t *testing.T) { clientset := runAgent(proxy.agent, stopCh) waitForConnectedServerCount(t, 1, clientset) - // run test client - tunnel, err := client.CreateSingleUseGrpcTunnel(ctx, proxy.front, grpc.WithInsecure()) - if err != nil { - t.Fatal(err) - } - var wg sync.WaitGroup verify := func() { defer wg.Done() + // run test client + tunnel, err := client.CreateSingleUseGrpcTunnel(ctx, proxy.front, grpc.WithInsecure()) + if err != nil { + t.Error(err) + return + } + c := &http.Client{ Transport: &http.Transport{ DialContext: tunnel.DialContext, @@ -66,11 +67,13 @@ func TestProxy_ConcurrencyGRPC(t *testing.T) { r, err := c.Get(server.URL) if err != nil { t.Error(err) + return } data, err := ioutil.ReadAll(r.Body) if err != nil { t.Error(err) + return } defer r.Body.Close()