From 6d1fe9c5aaf6daf20a4711799ba4ce778129e1ca Mon Sep 17 00:00:00 2001 From: Kian Parvin Date: Mon, 13 Jan 2025 17:18:41 +0200 Subject: [PATCH 1/2] chore: move model proxy to separate pkg The model proxy is currently housed within the rpc package. Originally I did this because the model proxy relied on the Juju RPC message definition declared in the rpc package. But this is likely not the right place for the model proxy as it has nothing to do with the dialer logic in the rpc package. Another supporting argument that the model proxy was in the wrong package is that the jimm package relies on the rpc package for dialing logic, but does not rely on the model proxy at all. And in fact, instead, the model proxy indirectly relies on the jimm package for business logic. --- internal/jujuapi/streamproxy.go | 4 +- internal/jujuapi/websocket.go | 17 +- internal/rpc/client.go | 12 +- internal/rpc/client_test.go | 333 +----------------- internal/rpc/dial _test.go | 7 +- internal/rpc/export_test.go | 5 +- internal/rpc/rpc.go | 4 +- internal/rpcproxy/export_test.go | 5 + internal/rpcproxy/message.go | 30 ++ .../{rpc/apiproxy.go => rpcproxy/rpcproxy.go} | 8 +- internal/rpcproxy/rpcproxy_test.go | 274 ++++++++++++++ .../rpcproxylogin_test.go} | 83 ++--- internal/{rpc => streamproxy}/streamproxy.go | 17 +- .../{rpc => streamproxy}/streamproxy_test.go | 44 ++- internal/testutils/rpctest/server.go | 96 +++++ 15 files changed, 514 insertions(+), 425 deletions(-) create mode 100644 internal/rpcproxy/export_test.go create mode 100644 internal/rpcproxy/message.go rename internal/{rpc/apiproxy.go => rpcproxy/rpcproxy.go} (98%) create mode 100644 internal/rpcproxy/rpcproxy_test.go rename internal/{rpc/apiproxy_test.go => rpcproxy/rpcproxylogin_test.go} (84%) rename internal/{rpc => streamproxy}/streamproxy.go (73%) rename internal/{rpc => streamproxy}/streamproxy_test.go (61%) create mode 100644 internal/testutils/rpctest/server.go diff --git a/internal/jujuapi/streamproxy.go b/internal/jujuapi/streamproxy.go index 5500c849a..fcbc3a9b0 100644 --- a/internal/jujuapi/streamproxy.go +++ b/internal/jujuapi/streamproxy.go @@ -17,7 +17,7 @@ import ( "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimmhttp" "github.com/canonical/jimm/v3/internal/openfga" - jimmRPC "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/streamproxy" ) // A streamProxier serves the the /log endpoint by proxying @@ -103,7 +103,7 @@ func (s streamProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) return } - jimmRPC.ProxyStreams(ctx, clientConn, controllerStream) + streamproxy.ProxyStreams(ctx, clientConn, controllerStream) } func checkPermission(ctx context.Context, path string, u *openfga.User, mt names.ModelTag) (bool, error) { diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index 981f83abc..325a0ed34 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -24,6 +24,7 @@ import ( "github.com/canonical/jimm/v3/internal/jimm/jujuauth" "github.com/canonical/jimm/v3/internal/jimmhttp" jimmRPC "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/rpcproxy" ) const ( @@ -177,7 +178,7 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { connectionFunc := controllerConnectionFunc(s, &jwtGenerator) zapctx.Debug(ctx, "Starting proxier") auditLogger := s.jimm.AuditLogManager().AddAuditLogEntry - proxyHelpers := jimmRPC.ProxyHelpers{ + proxyHelpers := rpcproxy.ProxyHelpers{ ConnClient: clientConn, TokenGen: &jwtGenerator, ConnectController: connectionFunc, @@ -185,22 +186,22 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { LoginService: s.jimm.LoginManager(), AuthenticatedIdentityID: auth.SessionIdentityFromContext(ctx), } - if err := jimmRPC.ProxySockets(ctx, proxyHelpers); err != nil { + if err := rpcproxy.ProxySockets(ctx, proxyHelpers); err != nil { zapctx.Error(ctx, "failed to start jimm model proxy", zap.Error(err)) } } // controllerConnectionFunc returns a function that will be used to // connect to a controller when a client makes a request. -func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { - return func(ctx context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { +func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerator) func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + return func(ctx context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { const op = errors.Op("proxy.controllerConnectionFunc") path := jimmhttp.PathElementFromContext(ctx, "path") zapctx.Debug(ctx, "grabbing model info from path", zap.String("path", path)) uuid, finalPath, err := modelInfoFromPath(path) if err != nil { zapctx.Error(ctx, "error parsing path", zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, errors.E(op, err) + return rpcproxy.WebsocketConnectionWithMetadata{}, errors.E(op, err) } m := dbmodel.Model{ UUID: sql.NullString{ @@ -210,7 +211,7 @@ func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerato } if err := s.jimm.Database.GetModel(context.Background(), &m); err != nil { zapctx.Error(ctx, "failed to find model", zap.String("uuid", uuid), zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, errors.E(err, errors.CodeNotFound) + return rpcproxy.WebsocketConnectionWithMetadata{}, errors.E(err, errors.CodeNotFound) } jwtGenerator.SetTags(m.ResourceTag(), m.Controller.ResourceTag()) mt := m.ResourceTag() @@ -218,10 +219,10 @@ func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerato controllerConn, err := jimmRPC.Dial(ctx, &m.Controller, mt, finalPath, nil) if err != nil { zapctx.Error(ctx, "cannot dial controller", zap.String("controller", m.Controller.Name), zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, err + return rpcproxy.WebsocketConnectionWithMetadata{}, err } fullModelName := m.Controller.Name + "/" + m.Name - return jimmRPC.WebsocketConnectionWithMetadata{ + return rpcproxy.WebsocketConnectionWithMetadata{ Conn: controllerConn, ControllerUUID: m.Controller.UUID, ModelName: fullModelName, diff --git a/internal/rpc/client.go b/internal/rpc/client.go index 61ec4f646..059ef5db7 100644 --- a/internal/rpc/client.go +++ b/internal/rpc/client.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc @@ -203,16 +203,6 @@ func (c *Client) Call(ctx context.Context, facade string, version int, id, metho select { case <-ch: - permissionsRequired, err := checkPermissionsRequired(ctx, *respMsg) - if err != nil { - return err - } - if permissionsRequired != nil { - return &Error{ - Code: PermissionCheckRequiredErrorCode, - Info: permissionsRequired, - } - } if (*respMsg).Error != "" { return &Error{ Message: (*respMsg).Error, diff --git a/internal/rpc/client_test.go b/internal/rpc/client_test.go index 7e94de97e..258da821c 100644 --- a/internal/rpc/client_test.go +++ b/internal/rpc/client_test.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc_test @@ -6,28 +6,24 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "net" "net/http" "net/http/httptest" "strings" "testing" - "time" qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" - "github.com/juju/names/v5" - "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func TestDialError(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() d := *srv.dialer d.TLSConfig = nil @@ -38,7 +34,7 @@ func TestDialError(t *testing.T) { func TestDial(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -48,7 +44,7 @@ func TestDial(t *testing.T) { func TestBasicDial(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.DialWebsocket(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -58,7 +54,7 @@ func TestBasicDial(t *testing.T) { func TestCallSuccess(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -76,7 +72,7 @@ func TestCallSuccess(t *testing.T) { func TestCallCanceledContext(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -137,7 +133,7 @@ func TestCallErrorResponse(t *testing.T) { if err := conn.WriteJSON(resp); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -181,7 +177,7 @@ func TestClientReceiveRequest(t *testing.T) { if err := conn.WriteJSON(req2); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -211,7 +207,7 @@ func TestClientReceiveInvalidMessage(t *testing.T) { if err := conn.WriteJSON(struct{}{}); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -224,268 +220,6 @@ func TestClientReceiveInvalidMessage(t *testing.T) { c.Check(res, qt.Equals, "") } -type testTokenGenerator struct{} - -func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { - return nil, nil -} - -func (p *testTokenGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { - return nil, nil -} - -func (p *testTokenGenerator) SetTags(names.ModelTag, names.ControllerTag) { -} - -func (p *testTokenGenerator) GetUser() names.UserTag { - return names.NewUserTag("testUser") -} - -func TestProxySockets(t *testing.T) { - c := qt.New(t) - ctx := context.Background() - - srvController := newServer(echo) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - receiveChan := make(chan error) - go func() { - receiveChan <- ws.ReadJSON(&resp) - }() - select { - case err := <-receiveChan: - c.Assert(err, qt.IsNil) - case <-time.After(5 * time.Second): - c.Logf("took too long to read response") - c.FailNow() - } - c.Assert(resp.Response, qt.DeepEquals, msg.Params) - ws.Close() - <-errChan // Ensure go routines are cleaned up -} - -func TestProxySocketsControllerConnectionFails(t *testing.T) { - c := qt.New(t) - ctx := context.Background() - - srvController := newServer(echo) - - var connController *websocket.Conn - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - var err error - connController, err = srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - receiveChan := make(chan error) - go func() { - receiveChan <- ws.ReadJSON(&resp) - }() - select { - case err := <-receiveChan: - c.Assert(err, qt.IsNil) - case <-time.After(5 * time.Second): - c.Logf("took too long to read response") - c.FailNow() - } - c.Assert(resp.Response, qt.DeepEquals, msg.Params) - - // Now close the connection to the controller and ensure the model proxy is cleaned up. - connController.Close() - <-errChan // Ensure go routines are cleaned up -} - -func TestCancelProxySockets(t *testing.T) { - c := qt.New(t) - - ctx, cancel := context.WithCancel(context.Background()) - - srvController := newServer(echo) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.ErrorMatches, "Context cancelled") - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - cancel() - <-errChan -} - -func TestProxySocketsAuditLogs(t *testing.T) { - c := qt.New(t) - - ctx := context.Background() - - srvController := newServer(echo) - auditLogs := make([]*dbmodel.AuditLogEntry, 0) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - defer connClient.Close() - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestModelName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) { auditLogs = append(auditLogs, ale) } - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - err = ws.ReadJSON(&resp) - c.Assert(err, qt.IsNil) - ws.Close() - <-errChan // Ensure go routines are cleaned up - c.Assert(auditLogs, qt.HasLen, 2) - expectedEvents := []*dbmodel.AuditLogEntry{{ - ID: auditLogs[0].ID, - Time: auditLogs[0].Time, - Model: "TestModelName", - ConversationId: auditLogs[0].ConversationId, - MessageId: 1, - FacadeName: "TestType", - FacadeMethod: "TestReq", - FacadeVersion: 0, - ObjectId: "", - IdentityTag: "user-testUser", - IsResponse: false, - Params: dbmodel.JSON(p), - Errors: nil, - }, { - ID: auditLogs[1].ID, - Time: auditLogs[1].Time, - Model: "TestModelName", - ConversationId: auditLogs[1].ConversationId, - MessageId: 1, - FacadeName: "", - FacadeMethod: "", - FacadeVersion: 0, - ObjectId: "", - IdentityTag: "user-testUser", - IsResponse: true, - Params: nil, - Errors: auditLogs[1].Errors, - }, - } - c.Assert(auditLogs, qt.DeepEquals, expectedEvents) - -} - type server struct { *httptest.Server @@ -495,7 +229,7 @@ type server struct { func newServer(f func(*websocket.Conn) error) *server { var srv server - srv.Server = httptest.NewTLSServer(handleWS(f)) + srv.Server = httptest.NewTLSServer(rpctest.HandleWS(f)) srv.URL = "ws" + strings.TrimPrefix(srv.Server.URL, "http") cp := x509.NewCertPool() cp.AddCert(srv.Certificate()) @@ -513,7 +247,7 @@ func newIPv6Server(f func(*websocket.Conn) error) *server { l, _ := net.Listen("tcp", "[::1]:0") server := httptest.Server{ Listener: l, - Config: &http.Server{Handler: handleWS(f)}, //nolint:gosec + Config: &http.Server{Handler: rpctest.HandleWS(f)}, //nolint:gosec } server.StartTLS() srv.Server = &server @@ -528,46 +262,3 @@ func newIPv6Server(f func(*websocket.Conn) error) *server { } return &srv } - -func handleWS(f func(*websocket.Conn) error) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var u websocket.Upgrader - c, err := u.Upgrade(w, req, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer c.Close() - err = f(c) - var cm []byte - closeError, isCloseError := err.(*websocket.CloseError) - switch { - case err == nil: - cm = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") - case isCloseError: - cm = websocket.FormatCloseMessage(closeError.Code, closeError.Text) - default: - cm = websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()) - } - _ = c.WriteControl(websocket.CloseMessage, cm, time.Time{}) - - }) -} - -func echo(c *websocket.Conn) error { - for { - msg := make(map[string]interface{}) - if err := c.ReadJSON(&msg); err != nil { - return err - } - delete(msg, "type") - delete(msg, "version") - delete(msg, "id") - delete(msg, "request") - msg["response"] = msg["params"] - delete(msg, "params") - if err := c.WriteJSON(msg); err != nil { - return err - } - } -} diff --git a/internal/rpc/dial _test.go b/internal/rpc/dial _test.go index bae202678..ac8d37a09 100644 --- a/internal/rpc/dial _test.go +++ b/internal/rpc/dial _test.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc_test @@ -15,12 +15,13 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func TestDialIPv4(t *testing.T) { c := qt.New(t) ctx := context.Background() - fakeController := newServer(echo) + fakeController := newServer(rpctest.Echo) defer fakeController.Close() controller := dbmodel.Controller{} pemData := pem.EncodeToMemory(&pem.Block{ @@ -44,7 +45,7 @@ func TestDialIPv4(t *testing.T) { func TestDialIPv6(t *testing.T) { c := qt.New(t) ctx := context.Background() - fakeController := newIPv6Server(echo) + fakeController := newIPv6Server(rpctest.Echo) defer fakeController.Close() controller := dbmodel.Controller{} pemData := pem.EncodeToMemory(&pem.Block{ diff --git a/internal/rpc/export_test.go b/internal/rpc/export_test.go index a50d4d277..583c041e9 100644 --- a/internal/rpc/export_test.go +++ b/internal/rpc/export_test.go @@ -1,4 +1,5 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. + package rpc -type Message message +type Message = message diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index 4090942a3..b7e7fc4d2 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. // Package rpc implements the juju RPC protocol. The main difference // between this implementation and the implementation in @@ -9,14 +9,12 @@ package rpc import ( "encoding/json" - "time" ) // A message encodes a single message sent, or received, over an RPC // connection. It contains the union of fields in a request or response // message. type message struct { - start time.Time RequestID uint64 `json:"request-id,omitempty"` Type string `json:"type,omitempty"` Version int `json:"version,omitempty"` diff --git a/internal/rpcproxy/export_test.go b/internal/rpcproxy/export_test.go new file mode 100644 index 000000000..a58cb9249 --- /dev/null +++ b/internal/rpcproxy/export_test.go @@ -0,0 +1,5 @@ +// Copyright 2025 Canonical. + +package rpcproxy + +type Message = message diff --git a/internal/rpcproxy/message.go b/internal/rpcproxy/message.go new file mode 100644 index 000000000..9fad5e478 --- /dev/null +++ b/internal/rpcproxy/message.go @@ -0,0 +1,30 @@ +// Copyright 2025 Canonical. + +package rpcproxy + +import ( + "encoding/json" + "time" +) + +// A message encodes a single message sent, or received, over an RPC +// connection. It contains the union of fields in a request or response +// message. +type message struct { + start time.Time + RequestID uint64 `json:"request-id,omitempty"` + Type string `json:"type,omitempty"` + Version int `json:"version,omitempty"` + ID string `json:"id,omitempty"` + Request string `json:"request,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Error string `json:"error,omitempty"` + ErrorCode string `json:"error-code,omitempty"` + ErrorInfo map[string]interface{} `json:"error-info,omitempty"` + Response json.RawMessage `json:"response,omitempty"` +} + +// isRequest returns whether the message is a request +func (m message) isRequest() bool { + return m.Type != "" && m.Request != "" +} diff --git a/internal/rpc/apiproxy.go b/internal/rpcproxy/rpcproxy.go similarity index 98% rename from internal/rpc/apiproxy.go rename to internal/rpcproxy/rpcproxy.go index f912d4138..e3375f0d4 100644 --- a/internal/rpc/apiproxy.go +++ b/internal/rpcproxy/rpcproxy.go @@ -1,5 +1,9 @@ -// Copyright 2024 Canonical. -package rpc +// Copyright 2025 Canonical. + +// Package rpcproxy implements a proxy for Juju's RPC messages. +// The rpcproxy is used to proxy messages between jimm and model facades +// on Juju controllers while still acting as an authorisation and routing layer. +package rpcproxy import ( "context" diff --git a/internal/rpcproxy/rpcproxy_test.go b/internal/rpcproxy/rpcproxy_test.go new file mode 100644 index 000000000..03030e710 --- /dev/null +++ b/internal/rpcproxy/rpcproxy_test.go @@ -0,0 +1,274 @@ +// Copyright 2025 Canonical. + +package rpcproxy_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/gorilla/websocket" + "github.com/juju/names/v5" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/internal/rpcproxy" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" +) + +type testTokenGenerator struct{} + +func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { + return nil, nil +} + +func (p *testTokenGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { + return nil, nil +} + +func (p *testTokenGenerator) SetTags(names.ModelTag, names.ControllerTag) { +} + +func (p *testTokenGenerator) GetUser() names.UserTag { + return names.NewUserTag("testUser") +} + +func TestProxySockets(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } + c.Assert(resp.Response, qt.DeepEquals, msg.Params) + ws.Close() + <-errChan // Ensure go routines are cleaned up +} + +func TestProxySocketsControllerConnectionFails(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + + var connController *websocket.Conn + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + var err error + connController = srvController.Dialer.DialWebsocket(c, srvController.URL) + c.Check(err, qt.IsNil) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } + c.Assert(resp.Response, qt.DeepEquals, msg.Params) + + // Now close the connection to the controller and ensure the model proxy is cleaned up. + connController.Close() + <-errChan // Ensure go routines are cleaned up +} + +func TestCancelProxySockets(t *testing.T) { + c := qt.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + + srvController := rpctest.NewServer(rpctest.Echo) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.ErrorMatches, "Context cancelled") + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + cancel() + <-errChan +} + +func TestProxySocketsAuditLogs(t *testing.T) { + c := qt.New(t) + + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + auditLogs := make([]*dbmodel.AuditLogEntry, 0) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + defer connClient.Close() + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestModelName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) { auditLogs = append(auditLogs, ale) } + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + err = ws.ReadJSON(&resp) + c.Assert(err, qt.IsNil) + ws.Close() + <-errChan // Ensure go routines are cleaned up + c.Assert(auditLogs, qt.HasLen, 2) + expectedEvents := []*dbmodel.AuditLogEntry{{ + ID: auditLogs[0].ID, + Time: auditLogs[0].Time, + Model: "TestModelName", + ConversationId: auditLogs[0].ConversationId, + MessageId: 1, + FacadeName: "TestType", + FacadeMethod: "TestReq", + FacadeVersion: 0, + ObjectId: "", + IdentityTag: "user-testUser", + IsResponse: false, + Params: dbmodel.JSON(p), + Errors: nil, + }, { + ID: auditLogs[1].ID, + Time: auditLogs[1].Time, + Model: "TestModelName", + ConversationId: auditLogs[1].ConversationId, + MessageId: 1, + FacadeName: "", + FacadeMethod: "", + FacadeVersion: 0, + ObjectId: "", + IdentityTag: "user-testUser", + IsResponse: true, + Params: nil, + Errors: auditLogs[1].Errors, + }, + } + c.Assert(auditLogs, qt.DeepEquals, expectedEvents) + +} diff --git a/internal/rpc/apiproxy_test.go b/internal/rpcproxy/rpcproxylogin_test.go similarity index 84% rename from internal/rpc/apiproxy_test.go rename to internal/rpcproxy/rpcproxylogin_test.go index cf4fe793f..393de8c81 100644 --- a/internal/rpc/apiproxy_test.go +++ b/internal/rpcproxy/rpcproxylogin_test.go @@ -1,6 +1,6 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. -package rpc_test +package rpcproxy_test import ( "context" @@ -19,24 +19,13 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/openfga" - "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/rpcproxy" apiparams "github.com/canonical/jimm/v3/pkg/api/params" jimmnames "github.com/canonical/jimm/v3/pkg/names" ) -type message struct { - RequestID uint64 `json:"request-id,omitempty"` - Type string `json:"type,omitempty"` - Version int `json:"version,omitempty"` - ID string `json:"id,omitempty"` - Request string `json:"request,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Error string `json:"error,omitempty"` - ErrorCode string `json:"error-code,omitempty"` - ErrorInfo map[string]interface{} `json:"error-info,omitempty"` - Response json.RawMessage `json:"response,omitempty"` -} - +// This test verifies that the ProxySockets function +// correctly handles login and authentication. func TestProxySocketsAdminFacade(t *testing.T) { c := qt.New(t) @@ -65,72 +54,72 @@ func TestProxySocketsAdminFacade(t *testing.T) { tests := []struct { about string - messageToSend message + messageToSend rpcproxy.Message authenticateEntityID string - expectedClientResponse *message - expectedControllerMessage *message + expectedClientResponse *rpcproxy.Message + expectedControllerMessage *rpcproxy.Message oauthAuthenticatorError error expectedProxyError string }{{ about: "login device call - client gets response with both user code and verification uri", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginDevice", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Response: []byte(`{"verification-uri":"http://no-such-uri.canonical.com","user-code":"test-user-code"}`), }, }, { about: "login device call, but the authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginDevice", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "a silly error", }, oauthAuthenticatorError: errors.E("a silly error"), }, { about: "get device session token call - client gets response with a session token", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "GetDeviceSessionToken", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Response: []byte(`{"session-token":"test session token"}`), }, }, { about: "get device session token call, but the authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "GetDeviceSessionToken", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "a silly error", }, oauthAuthenticatorError: errors.E("a silly error"), }, { about: "login with session token - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionToken", Params: []byte(`{"client-id": "test session token"}`), }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -139,14 +128,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session token, but authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionToken", Params: []byte(`{"client-id": "test session token"}`), }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -154,14 +143,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "login with client credentials - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithClientCredentials", Params: ccData, }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -170,14 +159,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with client credentials, but authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithClientCredentials", Params: ccData, }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -185,14 +174,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "any other message - gets forwarded directly to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Client", Version: 7, Request: "AnyMethod", Params: []byte(`{"key":"value"}`), }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Client", Version: 7, @@ -201,7 +190,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session cookie - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, @@ -209,7 +198,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { Params: ccData, }, authenticateEntityID: "alice@wonderland.io", - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -218,14 +207,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session cookie - but there was no identity id in the cookie", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionCookie", Params: ccData, }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -233,7 +222,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "connection to controller fails", - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ Error: "controller connection error", }, expectedProxyError: "failed to connect to controller: controller connection error", @@ -256,14 +245,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { err: test.oauthAuthenticatorError, } - helpers := rpc.ProxyHelpers{ + helpers := rpcproxy.ProxyHelpers{ ConnClient: clientWebsocket, TokenGen: &mockTokenGenerator{}, - ConnectController: func(ctx context.Context) (rpc.WebsocketConnectionWithMetadata, error) { + ConnectController: func(ctx context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { if proxyError { - return rpc.WebsocketConnectionWithMetadata{}, goerr.New("controller connection error") + return rpcproxy.WebsocketConnectionWithMetadata{}, goerr.New("controller connection error") } - return rpc.WebsocketConnectionWithMetadata{ + return rpcproxy.WebsocketConnectionWithMetadata{ Conn: controllerWebsocket, ModelName: "test model", ControllerUUID: uuid.NewString(), @@ -277,7 +266,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err = rpc.ProxySockets(ctx, helpers) + err = rpcproxy.ProxySockets(ctx, helpers) if proxyError { c.Assert(err, qt.ErrorMatches, test.expectedProxyError) } else { diff --git a/internal/rpc/streamproxy.go b/internal/streamproxy/streamproxy.go similarity index 73% rename from internal/rpc/streamproxy.go rename to internal/streamproxy/streamproxy.go index 5da96206f..1fbd38732 100644 --- a/internal/rpc/streamproxy.go +++ b/internal/streamproxy/streamproxy.go @@ -1,10 +1,12 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. -package rpc +package streamproxy import ( "context" + "encoding/json" + "github.com/gorilla/websocket" "github.com/juju/juju/api/base" "github.com/juju/zaputil/zapctx" "go.uber.org/zap" @@ -46,3 +48,14 @@ func proxy(src base.Stream, dst base.Stream) error { } } } + +func unexpectedReadError(err error) bool { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure) { + return true + } + _, unmarshalError := err.(*json.InvalidUnmarshalError) + return unmarshalError +} diff --git a/internal/rpc/streamproxy_test.go b/internal/streamproxy/streamproxy_test.go similarity index 61% rename from internal/rpc/streamproxy_test.go rename to internal/streamproxy/streamproxy_test.go index 2385d2109..890b2af4d 100644 --- a/internal/rpc/streamproxy_test.go +++ b/internal/streamproxy/streamproxy_test.go @@ -1,5 +1,6 @@ -// Copyright 2024 Canonical. -package rpc_test +// Copyright 2025 Canonical. + +package streamproxy_test import ( "context" @@ -11,7 +12,8 @@ import ( qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" - "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/streamproxy" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func echoSingleMessage(c *websocket.Conn) error { @@ -54,18 +56,16 @@ func TestStreamProxy(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(echoSingleMessage) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(echoSingleMessage) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, "") @@ -79,18 +79,16 @@ func TestStreamProxyStoppedController(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(func(c *websocket.Conn) error { return errors.New("stopped") }) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(func(c *websocket.Conn) error { return errors.New("stopped") }) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, ".*abnormal closure.*") @@ -104,18 +102,16 @@ func TestStreamProxyStoppedMidwayController(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(echoSingleMessage) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(echoSingleMessage) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, "") diff --git a/internal/testutils/rpctest/server.go b/internal/testutils/rpctest/server.go new file mode 100644 index 000000000..92ec1b7c4 --- /dev/null +++ b/internal/testutils/rpctest/server.go @@ -0,0 +1,96 @@ +// Copyright 2025 Canonical. + +package rpctest + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/httptest" + "strings" + "time" + + qt "github.com/frankban/quicktest" + "github.com/gorilla/websocket" +) + +type testDialer struct { + tlsConfig *tls.Config +} + +// Dial establishes a new client RPC connection to the given URL. +func (d *testDialer) DialWebsocket(c *qt.C, url string) *websocket.Conn { + dialer := websocket.Dialer{ + TLSClientConfig: d.tlsConfig, + } + conn, resp, err := dialer.DialContext(context.Background(), url, nil) + c.Assert(err, qt.IsNil) + defer resp.Body.Close() + return conn +} + +type Server struct { + *httptest.Server + + URL string + Dialer *testDialer +} + +func NewServer(f func(*websocket.Conn) error) *Server { + var srv Server + srv.Server = httptest.NewTLSServer(HandleWS(f)) + srv.URL = "ws" + strings.TrimPrefix(srv.Server.URL, "http") + cp := x509.NewCertPool() + cp.AddCert(srv.Certificate()) + srv.Dialer = &testDialer{ + tlsConfig: &tls.Config{ + RootCAs: cp, + MinVersion: tls.VersionTLS12, + }, + } + return &srv +} + +func HandleWS(f func(*websocket.Conn) error) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var u websocket.Upgrader + c, err := u.Upgrade(w, req, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer c.Close() + err = f(c) + var cm []byte + closeError, isCloseError := err.(*websocket.CloseError) + switch { + case err == nil: + cm = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + case isCloseError: + cm = websocket.FormatCloseMessage(closeError.Code, closeError.Text) + default: + cm = websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()) + } + _ = c.WriteControl(websocket.CloseMessage, cm, time.Time{}) + + }) +} + +func Echo(c *websocket.Conn) error { + for { + msg := make(map[string]interface{}) + if err := c.ReadJSON(&msg); err != nil { + return err + } + delete(msg, "type") + delete(msg, "version") + delete(msg, "id") + delete(msg, "request") + msg["response"] = msg["params"] + delete(msg, "params") + if err := c.WriteJSON(msg); err != nil { + return err + } + } +} From c778fc6fd9afa275767520f758d58a5d0fc79186 Mon Sep 17 00:00:00 2001 From: Kian Parvin Date: Tue, 14 Jan 2025 09:56:22 +0200 Subject: [PATCH 2/2] chore: remove unused isRequest func --- internal/rpcproxy/message.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/internal/rpcproxy/message.go b/internal/rpcproxy/message.go index 9fad5e478..71ee3b151 100644 --- a/internal/rpcproxy/message.go +++ b/internal/rpcproxy/message.go @@ -23,8 +23,3 @@ type message struct { ErrorInfo map[string]interface{} `json:"error-info,omitempty"` Response json.RawMessage `json:"response,omitempty"` } - -// isRequest returns whether the message is a request -func (m message) isRequest() bool { - return m.Type != "" && m.Request != "" -}