diff --git a/internal/jujuapi/streamproxy.go b/internal/jujuapi/streamproxy.go index 5500c849a..fcbc3a9b0 100644 --- a/internal/jujuapi/streamproxy.go +++ b/internal/jujuapi/streamproxy.go @@ -17,7 +17,7 @@ import ( "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimmhttp" "github.com/canonical/jimm/v3/internal/openfga" - jimmRPC "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/streamproxy" ) // A streamProxier serves the the /log endpoint by proxying @@ -103,7 +103,7 @@ func (s streamProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) return } - jimmRPC.ProxyStreams(ctx, clientConn, controllerStream) + streamproxy.ProxyStreams(ctx, clientConn, controllerStream) } func checkPermission(ctx context.Context, path string, u *openfga.User, mt names.ModelTag) (bool, error) { diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index 981f83abc..325a0ed34 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -24,6 +24,7 @@ import ( "github.com/canonical/jimm/v3/internal/jimm/jujuauth" "github.com/canonical/jimm/v3/internal/jimmhttp" jimmRPC "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/rpcproxy" ) const ( @@ -177,7 +178,7 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { connectionFunc := controllerConnectionFunc(s, &jwtGenerator) zapctx.Debug(ctx, "Starting proxier") auditLogger := s.jimm.AuditLogManager().AddAuditLogEntry - proxyHelpers := jimmRPC.ProxyHelpers{ + proxyHelpers := rpcproxy.ProxyHelpers{ ConnClient: clientConn, TokenGen: &jwtGenerator, ConnectController: connectionFunc, @@ -185,22 +186,22 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { LoginService: s.jimm.LoginManager(), AuthenticatedIdentityID: auth.SessionIdentityFromContext(ctx), } - if err := jimmRPC.ProxySockets(ctx, proxyHelpers); err != nil { + if err := rpcproxy.ProxySockets(ctx, proxyHelpers); err != nil { zapctx.Error(ctx, "failed to start jimm model proxy", zap.Error(err)) } } // controllerConnectionFunc returns a function that will be used to // connect to a controller when a client makes a request. -func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { - return func(ctx context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { +func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerator) func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + return func(ctx context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { const op = errors.Op("proxy.controllerConnectionFunc") path := jimmhttp.PathElementFromContext(ctx, "path") zapctx.Debug(ctx, "grabbing model info from path", zap.String("path", path)) uuid, finalPath, err := modelInfoFromPath(path) if err != nil { zapctx.Error(ctx, "error parsing path", zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, errors.E(op, err) + return rpcproxy.WebsocketConnectionWithMetadata{}, errors.E(op, err) } m := dbmodel.Model{ UUID: sql.NullString{ @@ -210,7 +211,7 @@ func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerato } if err := s.jimm.Database.GetModel(context.Background(), &m); err != nil { zapctx.Error(ctx, "failed to find model", zap.String("uuid", uuid), zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, errors.E(err, errors.CodeNotFound) + return rpcproxy.WebsocketConnectionWithMetadata{}, errors.E(err, errors.CodeNotFound) } jwtGenerator.SetTags(m.ResourceTag(), m.Controller.ResourceTag()) mt := m.ResourceTag() @@ -218,10 +219,10 @@ func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerato controllerConn, err := jimmRPC.Dial(ctx, &m.Controller, mt, finalPath, nil) if err != nil { zapctx.Error(ctx, "cannot dial controller", zap.String("controller", m.Controller.Name), zap.Error(err)) - return jimmRPC.WebsocketConnectionWithMetadata{}, err + return rpcproxy.WebsocketConnectionWithMetadata{}, err } fullModelName := m.Controller.Name + "/" + m.Name - return jimmRPC.WebsocketConnectionWithMetadata{ + return rpcproxy.WebsocketConnectionWithMetadata{ Conn: controllerConn, ControllerUUID: m.Controller.UUID, ModelName: fullModelName, diff --git a/internal/rpc/client.go b/internal/rpc/client.go index 61ec4f646..059ef5db7 100644 --- a/internal/rpc/client.go +++ b/internal/rpc/client.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc @@ -203,16 +203,6 @@ func (c *Client) Call(ctx context.Context, facade string, version int, id, metho select { case <-ch: - permissionsRequired, err := checkPermissionsRequired(ctx, *respMsg) - if err != nil { - return err - } - if permissionsRequired != nil { - return &Error{ - Code: PermissionCheckRequiredErrorCode, - Info: permissionsRequired, - } - } if (*respMsg).Error != "" { return &Error{ Message: (*respMsg).Error, diff --git a/internal/rpc/client_test.go b/internal/rpc/client_test.go index 7e94de97e..258da821c 100644 --- a/internal/rpc/client_test.go +++ b/internal/rpc/client_test.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc_test @@ -6,28 +6,24 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "net" "net/http" "net/http/httptest" "strings" "testing" - "time" qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" - "github.com/juju/names/v5" - "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func TestDialError(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() d := *srv.dialer d.TLSConfig = nil @@ -38,7 +34,7 @@ func TestDialError(t *testing.T) { func TestDial(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -48,7 +44,7 @@ func TestDial(t *testing.T) { func TestBasicDial(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.DialWebsocket(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -58,7 +54,7 @@ func TestBasicDial(t *testing.T) { func TestCallSuccess(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -76,7 +72,7 @@ func TestCallSuccess(t *testing.T) { func TestCallCanceledContext(t *testing.T) { c := qt.New(t) - srv := newServer(echo) + srv := newServer(rpctest.Echo) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) c.Assert(err, qt.IsNil) @@ -137,7 +133,7 @@ func TestCallErrorResponse(t *testing.T) { if err := conn.WriteJSON(resp); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -181,7 +177,7 @@ func TestClientReceiveRequest(t *testing.T) { if err := conn.WriteJSON(req2); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -211,7 +207,7 @@ func TestClientReceiveInvalidMessage(t *testing.T) { if err := conn.WriteJSON(struct{}{}); err != nil { return err } - return echo(conn) + return rpctest.Echo(conn) }) defer srv.Close() conn, err := srv.dialer.Dial(context.Background(), srv.URL, nil) @@ -224,268 +220,6 @@ func TestClientReceiveInvalidMessage(t *testing.T) { c.Check(res, qt.Equals, "") } -type testTokenGenerator struct{} - -func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { - return nil, nil -} - -func (p *testTokenGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { - return nil, nil -} - -func (p *testTokenGenerator) SetTags(names.ModelTag, names.ControllerTag) { -} - -func (p *testTokenGenerator) GetUser() names.UserTag { - return names.NewUserTag("testUser") -} - -func TestProxySockets(t *testing.T) { - c := qt.New(t) - ctx := context.Background() - - srvController := newServer(echo) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - receiveChan := make(chan error) - go func() { - receiveChan <- ws.ReadJSON(&resp) - }() - select { - case err := <-receiveChan: - c.Assert(err, qt.IsNil) - case <-time.After(5 * time.Second): - c.Logf("took too long to read response") - c.FailNow() - } - c.Assert(resp.Response, qt.DeepEquals, msg.Params) - ws.Close() - <-errChan // Ensure go routines are cleaned up -} - -func TestProxySocketsControllerConnectionFails(t *testing.T) { - c := qt.New(t) - ctx := context.Background() - - srvController := newServer(echo) - - var connController *websocket.Conn - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - var err error - connController, err = srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - receiveChan := make(chan error) - go func() { - receiveChan <- ws.ReadJSON(&resp) - }() - select { - case err := <-receiveChan: - c.Assert(err, qt.IsNil) - case <-time.After(5 * time.Second): - c.Logf("took too long to read response") - c.FailNow() - } - c.Assert(resp.Response, qt.DeepEquals, msg.Params) - - // Now close the connection to the controller and ensure the model proxy is cleaned up. - connController.Close() - <-errChan // Ensure go routines are cleaned up -} - -func TestCancelProxySockets(t *testing.T) { - c := qt.New(t) - - ctx, cancel := context.WithCancel(context.Background()) - - srvController := newServer(echo) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) {} - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.ErrorMatches, "Context cancelled") - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - cancel() - <-errChan -} - -func TestProxySocketsAuditLogs(t *testing.T) { - c := qt.New(t) - - ctx := context.Background() - - srvController := newServer(echo) - auditLogs := make([]*dbmodel.AuditLogEntry, 0) - - errChan := make(chan error) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - defer connClient.Close() - testTokenGen := testTokenGenerator{} - f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Check(err, qt.IsNil) - return rpc.WebsocketConnectionWithMetadata{ - Conn: connController, - ModelName: "TestModelName", - }, nil - } - auditLogger := func(ale *dbmodel.AuditLogEntry) { auditLogs = append(auditLogs, ale) } - proxyHelpers := rpc.ProxyHelpers{ - ConnClient: connClient, - TokenGen: &testTokenGen, - ConnectController: f, - AuditLog: auditLogger, - LoginService: &mockLoginService{}, - } - err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.IsNil) - errChan <- err - return err - }) - - defer srvController.Close() - defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) - defer ws.Close() - - p := json.RawMessage(`{"Key":"TestVal"}`) - msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} - err = ws.WriteJSON(&msg) - c.Assert(err, qt.IsNil) - resp := rpc.Message{} - err = ws.ReadJSON(&resp) - c.Assert(err, qt.IsNil) - ws.Close() - <-errChan // Ensure go routines are cleaned up - c.Assert(auditLogs, qt.HasLen, 2) - expectedEvents := []*dbmodel.AuditLogEntry{{ - ID: auditLogs[0].ID, - Time: auditLogs[0].Time, - Model: "TestModelName", - ConversationId: auditLogs[0].ConversationId, - MessageId: 1, - FacadeName: "TestType", - FacadeMethod: "TestReq", - FacadeVersion: 0, - ObjectId: "", - IdentityTag: "user-testUser", - IsResponse: false, - Params: dbmodel.JSON(p), - Errors: nil, - }, { - ID: auditLogs[1].ID, - Time: auditLogs[1].Time, - Model: "TestModelName", - ConversationId: auditLogs[1].ConversationId, - MessageId: 1, - FacadeName: "", - FacadeMethod: "", - FacadeVersion: 0, - ObjectId: "", - IdentityTag: "user-testUser", - IsResponse: true, - Params: nil, - Errors: auditLogs[1].Errors, - }, - } - c.Assert(auditLogs, qt.DeepEquals, expectedEvents) - -} - type server struct { *httptest.Server @@ -495,7 +229,7 @@ type server struct { func newServer(f func(*websocket.Conn) error) *server { var srv server - srv.Server = httptest.NewTLSServer(handleWS(f)) + srv.Server = httptest.NewTLSServer(rpctest.HandleWS(f)) srv.URL = "ws" + strings.TrimPrefix(srv.Server.URL, "http") cp := x509.NewCertPool() cp.AddCert(srv.Certificate()) @@ -513,7 +247,7 @@ func newIPv6Server(f func(*websocket.Conn) error) *server { l, _ := net.Listen("tcp", "[::1]:0") server := httptest.Server{ Listener: l, - Config: &http.Server{Handler: handleWS(f)}, //nolint:gosec + Config: &http.Server{Handler: rpctest.HandleWS(f)}, //nolint:gosec } server.StartTLS() srv.Server = &server @@ -528,46 +262,3 @@ func newIPv6Server(f func(*websocket.Conn) error) *server { } return &srv } - -func handleWS(f func(*websocket.Conn) error) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var u websocket.Upgrader - c, err := u.Upgrade(w, req, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer c.Close() - err = f(c) - var cm []byte - closeError, isCloseError := err.(*websocket.CloseError) - switch { - case err == nil: - cm = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") - case isCloseError: - cm = websocket.FormatCloseMessage(closeError.Code, closeError.Text) - default: - cm = websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()) - } - _ = c.WriteControl(websocket.CloseMessage, cm, time.Time{}) - - }) -} - -func echo(c *websocket.Conn) error { - for { - msg := make(map[string]interface{}) - if err := c.ReadJSON(&msg); err != nil { - return err - } - delete(msg, "type") - delete(msg, "version") - delete(msg, "id") - delete(msg, "request") - msg["response"] = msg["params"] - delete(msg, "params") - if err := c.WriteJSON(msg); err != nil { - return err - } - } -} diff --git a/internal/rpc/dial _test.go b/internal/rpc/dial _test.go index bae202678..ac8d37a09 100644 --- a/internal/rpc/dial _test.go +++ b/internal/rpc/dial _test.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package rpc_test @@ -15,12 +15,13 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func TestDialIPv4(t *testing.T) { c := qt.New(t) ctx := context.Background() - fakeController := newServer(echo) + fakeController := newServer(rpctest.Echo) defer fakeController.Close() controller := dbmodel.Controller{} pemData := pem.EncodeToMemory(&pem.Block{ @@ -44,7 +45,7 @@ func TestDialIPv4(t *testing.T) { func TestDialIPv6(t *testing.T) { c := qt.New(t) ctx := context.Background() - fakeController := newIPv6Server(echo) + fakeController := newIPv6Server(rpctest.Echo) defer fakeController.Close() controller := dbmodel.Controller{} pemData := pem.EncodeToMemory(&pem.Block{ diff --git a/internal/rpc/export_test.go b/internal/rpc/export_test.go index a50d4d277..583c041e9 100644 --- a/internal/rpc/export_test.go +++ b/internal/rpc/export_test.go @@ -1,4 +1,5 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. + package rpc -type Message message +type Message = message diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index 4090942a3..b7e7fc4d2 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. // Package rpc implements the juju RPC protocol. The main difference // between this implementation and the implementation in @@ -9,14 +9,12 @@ package rpc import ( "encoding/json" - "time" ) // A message encodes a single message sent, or received, over an RPC // connection. It contains the union of fields in a request or response // message. type message struct { - start time.Time RequestID uint64 `json:"request-id,omitempty"` Type string `json:"type,omitempty"` Version int `json:"version,omitempty"` diff --git a/internal/rpcproxy/export_test.go b/internal/rpcproxy/export_test.go new file mode 100644 index 000000000..a58cb9249 --- /dev/null +++ b/internal/rpcproxy/export_test.go @@ -0,0 +1,5 @@ +// Copyright 2025 Canonical. + +package rpcproxy + +type Message = message diff --git a/internal/rpcproxy/message.go b/internal/rpcproxy/message.go new file mode 100644 index 000000000..9fad5e478 --- /dev/null +++ b/internal/rpcproxy/message.go @@ -0,0 +1,30 @@ +// Copyright 2025 Canonical. + +package rpcproxy + +import ( + "encoding/json" + "time" +) + +// A message encodes a single message sent, or received, over an RPC +// connection. It contains the union of fields in a request or response +// message. +type message struct { + start time.Time + RequestID uint64 `json:"request-id,omitempty"` + Type string `json:"type,omitempty"` + Version int `json:"version,omitempty"` + ID string `json:"id,omitempty"` + Request string `json:"request,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Error string `json:"error,omitempty"` + ErrorCode string `json:"error-code,omitempty"` + ErrorInfo map[string]interface{} `json:"error-info,omitempty"` + Response json.RawMessage `json:"response,omitempty"` +} + +// isRequest returns whether the message is a request +func (m message) isRequest() bool { + return m.Type != "" && m.Request != "" +} diff --git a/internal/rpc/apiproxy.go b/internal/rpcproxy/rpcproxy.go similarity index 98% rename from internal/rpc/apiproxy.go rename to internal/rpcproxy/rpcproxy.go index f912d4138..e3375f0d4 100644 --- a/internal/rpc/apiproxy.go +++ b/internal/rpcproxy/rpcproxy.go @@ -1,5 +1,9 @@ -// Copyright 2024 Canonical. -package rpc +// Copyright 2025 Canonical. + +// Package rpcproxy implements a proxy for Juju's RPC messages. +// The rpcproxy is used to proxy messages between jimm and model facades +// on Juju controllers while still acting as an authorisation and routing layer. +package rpcproxy import ( "context" diff --git a/internal/rpcproxy/rpcproxy_test.go b/internal/rpcproxy/rpcproxy_test.go new file mode 100644 index 000000000..03030e710 --- /dev/null +++ b/internal/rpcproxy/rpcproxy_test.go @@ -0,0 +1,274 @@ +// Copyright 2025 Canonical. + +package rpcproxy_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/gorilla/websocket" + "github.com/juju/names/v5" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/internal/rpcproxy" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" +) + +type testTokenGenerator struct{} + +func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { + return nil, nil +} + +func (p *testTokenGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { + return nil, nil +} + +func (p *testTokenGenerator) SetTags(names.ModelTag, names.ControllerTag) { +} + +func (p *testTokenGenerator) GetUser() names.UserTag { + return names.NewUserTag("testUser") +} + +func TestProxySockets(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } + c.Assert(resp.Response, qt.DeepEquals, msg.Params) + ws.Close() + <-errChan // Ensure go routines are cleaned up +} + +func TestProxySocketsControllerConnectionFails(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + + var connController *websocket.Conn + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + var err error + connController = srvController.Dialer.DialWebsocket(c, srvController.URL) + c.Check(err, qt.IsNil) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } + c.Assert(resp.Response, qt.DeepEquals, msg.Params) + + // Now close the connection to the controller and ensure the model proxy is cleaned up. + connController.Close() + <-errChan // Ensure go routines are cleaned up +} + +func TestCancelProxySockets(t *testing.T) { + c := qt.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + + srvController := rpctest.NewServer(rpctest.Echo) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.ErrorMatches, "Context cancelled") + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + cancel() + <-errChan +} + +func TestProxySocketsAuditLogs(t *testing.T) { + c := qt.New(t) + + ctx := context.Background() + + srvController := rpctest.NewServer(rpctest.Echo) + auditLogs := make([]*dbmodel.AuditLogEntry, 0) + + errChan := make(chan error) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + defer connClient.Close() + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + return rpcproxy.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestModelName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) { auditLogs = append(auditLogs, ale) } + proxyHelpers := rpcproxy.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpcproxy.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpcproxy.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err := ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpcproxy.Message{} + err = ws.ReadJSON(&resp) + c.Assert(err, qt.IsNil) + ws.Close() + <-errChan // Ensure go routines are cleaned up + c.Assert(auditLogs, qt.HasLen, 2) + expectedEvents := []*dbmodel.AuditLogEntry{{ + ID: auditLogs[0].ID, + Time: auditLogs[0].Time, + Model: "TestModelName", + ConversationId: auditLogs[0].ConversationId, + MessageId: 1, + FacadeName: "TestType", + FacadeMethod: "TestReq", + FacadeVersion: 0, + ObjectId: "", + IdentityTag: "user-testUser", + IsResponse: false, + Params: dbmodel.JSON(p), + Errors: nil, + }, { + ID: auditLogs[1].ID, + Time: auditLogs[1].Time, + Model: "TestModelName", + ConversationId: auditLogs[1].ConversationId, + MessageId: 1, + FacadeName: "", + FacadeMethod: "", + FacadeVersion: 0, + ObjectId: "", + IdentityTag: "user-testUser", + IsResponse: true, + Params: nil, + Errors: auditLogs[1].Errors, + }, + } + c.Assert(auditLogs, qt.DeepEquals, expectedEvents) + +} diff --git a/internal/rpc/apiproxy_test.go b/internal/rpcproxy/rpcproxylogin_test.go similarity index 84% rename from internal/rpc/apiproxy_test.go rename to internal/rpcproxy/rpcproxylogin_test.go index cf4fe793f..393de8c81 100644 --- a/internal/rpc/apiproxy_test.go +++ b/internal/rpcproxy/rpcproxylogin_test.go @@ -1,6 +1,6 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. -package rpc_test +package rpcproxy_test import ( "context" @@ -19,24 +19,13 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/openfga" - "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/rpcproxy" apiparams "github.com/canonical/jimm/v3/pkg/api/params" jimmnames "github.com/canonical/jimm/v3/pkg/names" ) -type message struct { - RequestID uint64 `json:"request-id,omitempty"` - Type string `json:"type,omitempty"` - Version int `json:"version,omitempty"` - ID string `json:"id,omitempty"` - Request string `json:"request,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Error string `json:"error,omitempty"` - ErrorCode string `json:"error-code,omitempty"` - ErrorInfo map[string]interface{} `json:"error-info,omitempty"` - Response json.RawMessage `json:"response,omitempty"` -} - +// This test verifies that the ProxySockets function +// correctly handles login and authentication. func TestProxySocketsAdminFacade(t *testing.T) { c := qt.New(t) @@ -65,72 +54,72 @@ func TestProxySocketsAdminFacade(t *testing.T) { tests := []struct { about string - messageToSend message + messageToSend rpcproxy.Message authenticateEntityID string - expectedClientResponse *message - expectedControllerMessage *message + expectedClientResponse *rpcproxy.Message + expectedControllerMessage *rpcproxy.Message oauthAuthenticatorError error expectedProxyError string }{{ about: "login device call - client gets response with both user code and verification uri", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginDevice", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Response: []byte(`{"verification-uri":"http://no-such-uri.canonical.com","user-code":"test-user-code"}`), }, }, { about: "login device call, but the authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginDevice", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "a silly error", }, oauthAuthenticatorError: errors.E("a silly error"), }, { about: "get device session token call - client gets response with a session token", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "GetDeviceSessionToken", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Response: []byte(`{"session-token":"test session token"}`), }, }, { about: "get device session token call, but the authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "GetDeviceSessionToken", }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "a silly error", }, oauthAuthenticatorError: errors.E("a silly error"), }, { about: "login with session token - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionToken", Params: []byte(`{"client-id": "test session token"}`), }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -139,14 +128,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session token, but authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionToken", Params: []byte(`{"client-id": "test session token"}`), }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -154,14 +143,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "login with client credentials - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithClientCredentials", Params: ccData, }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -170,14 +159,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with client credentials, but authenticator returns an error", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithClientCredentials", Params: ccData, }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -185,14 +174,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "any other message - gets forwarded directly to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Client", Version: 7, Request: "AnyMethod", Params: []byte(`{"key":"value"}`), }, - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Client", Version: 7, @@ -201,7 +190,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session cookie - a login message is sent to the controller", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, @@ -209,7 +198,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { Params: ccData, }, authenticateEntityID: "alice@wonderland.io", - expectedControllerMessage: &message{ + expectedControllerMessage: &rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 3, @@ -218,14 +207,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, }, { about: "login with session cookie - but there was no identity id in the cookie", - messageToSend: message{ + messageToSend: rpcproxy.Message{ RequestID: 1, Type: "Admin", Version: 4, Request: "LoginWithSessionCookie", Params: ccData, }, - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ RequestID: 1, Error: "unauthorized access", ErrorCode: "unauthorized access", @@ -233,7 +222,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { about: "connection to controller fails", - expectedClientResponse: &message{ + expectedClientResponse: &rpcproxy.Message{ Error: "controller connection error", }, expectedProxyError: "failed to connect to controller: controller connection error", @@ -256,14 +245,14 @@ func TestProxySocketsAdminFacade(t *testing.T) { err: test.oauthAuthenticatorError, } - helpers := rpc.ProxyHelpers{ + helpers := rpcproxy.ProxyHelpers{ ConnClient: clientWebsocket, TokenGen: &mockTokenGenerator{}, - ConnectController: func(ctx context.Context) (rpc.WebsocketConnectionWithMetadata, error) { + ConnectController: func(ctx context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) { if proxyError { - return rpc.WebsocketConnectionWithMetadata{}, goerr.New("controller connection error") + return rpcproxy.WebsocketConnectionWithMetadata{}, goerr.New("controller connection error") } - return rpc.WebsocketConnectionWithMetadata{ + return rpcproxy.WebsocketConnectionWithMetadata{ Conn: controllerWebsocket, ModelName: "test model", ControllerUUID: uuid.NewString(), @@ -277,7 +266,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err = rpc.ProxySockets(ctx, helpers) + err = rpcproxy.ProxySockets(ctx, helpers) if proxyError { c.Assert(err, qt.ErrorMatches, test.expectedProxyError) } else { diff --git a/internal/rpc/streamproxy.go b/internal/streamproxy/streamproxy.go similarity index 73% rename from internal/rpc/streamproxy.go rename to internal/streamproxy/streamproxy.go index 5da96206f..1fbd38732 100644 --- a/internal/rpc/streamproxy.go +++ b/internal/streamproxy/streamproxy.go @@ -1,10 +1,12 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. -package rpc +package streamproxy import ( "context" + "encoding/json" + "github.com/gorilla/websocket" "github.com/juju/juju/api/base" "github.com/juju/zaputil/zapctx" "go.uber.org/zap" @@ -46,3 +48,14 @@ func proxy(src base.Stream, dst base.Stream) error { } } } + +func unexpectedReadError(err error) bool { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure) { + return true + } + _, unmarshalError := err.(*json.InvalidUnmarshalError) + return unmarshalError +} diff --git a/internal/rpc/streamproxy_test.go b/internal/streamproxy/streamproxy_test.go similarity index 61% rename from internal/rpc/streamproxy_test.go rename to internal/streamproxy/streamproxy_test.go index 2385d2109..890b2af4d 100644 --- a/internal/rpc/streamproxy_test.go +++ b/internal/streamproxy/streamproxy_test.go @@ -1,5 +1,6 @@ -// Copyright 2024 Canonical. -package rpc_test +// Copyright 2025 Canonical. + +package streamproxy_test import ( "context" @@ -11,7 +12,8 @@ import ( qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" - "github.com/canonical/jimm/v3/internal/rpc" + "github.com/canonical/jimm/v3/internal/streamproxy" + "github.com/canonical/jimm/v3/internal/testutils/rpctest" ) func echoSingleMessage(c *websocket.Conn) error { @@ -54,18 +56,16 @@ func TestStreamProxy(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(echoSingleMessage) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(echoSingleMessage) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, "") @@ -79,18 +79,16 @@ func TestStreamProxyStoppedController(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(func(c *websocket.Conn) error { return errors.New("stopped") }) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(func(c *websocket.Conn) error { return errors.New("stopped") }) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, ".*abnormal closure.*") @@ -104,18 +102,16 @@ func TestStreamProxyStoppedMidwayController(t *testing.T) { ctx := context.Background() doneChan := make(chan error) - srvController := newServer(echoSingleMessage) - srvJIMM := newServer(func(connClient *websocket.Conn) error { - connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL, nil) - c.Assert(err, qt.IsNil) - rpc.ProxyStreams(ctx, connClient, connController) + srvController := rpctest.NewServer(echoSingleMessage) + srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error { + connController := srvController.Dialer.DialWebsocket(c, srvController.URL) + streamproxy.ProxyStreams(ctx, connClient, connController) doneChan <- nil return nil }) defer srvController.Close() defer srvJIMM.Close() - ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL, nil) - c.Assert(err, qt.IsNil) + ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL) defer ws.Close() verifyEcho(c, ws, "") diff --git a/internal/testutils/rpctest/server.go b/internal/testutils/rpctest/server.go new file mode 100644 index 000000000..92ec1b7c4 --- /dev/null +++ b/internal/testutils/rpctest/server.go @@ -0,0 +1,96 @@ +// Copyright 2025 Canonical. + +package rpctest + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/httptest" + "strings" + "time" + + qt "github.com/frankban/quicktest" + "github.com/gorilla/websocket" +) + +type testDialer struct { + tlsConfig *tls.Config +} + +// Dial establishes a new client RPC connection to the given URL. +func (d *testDialer) DialWebsocket(c *qt.C, url string) *websocket.Conn { + dialer := websocket.Dialer{ + TLSClientConfig: d.tlsConfig, + } + conn, resp, err := dialer.DialContext(context.Background(), url, nil) + c.Assert(err, qt.IsNil) + defer resp.Body.Close() + return conn +} + +type Server struct { + *httptest.Server + + URL string + Dialer *testDialer +} + +func NewServer(f func(*websocket.Conn) error) *Server { + var srv Server + srv.Server = httptest.NewTLSServer(HandleWS(f)) + srv.URL = "ws" + strings.TrimPrefix(srv.Server.URL, "http") + cp := x509.NewCertPool() + cp.AddCert(srv.Certificate()) + srv.Dialer = &testDialer{ + tlsConfig: &tls.Config{ + RootCAs: cp, + MinVersion: tls.VersionTLS12, + }, + } + return &srv +} + +func HandleWS(f func(*websocket.Conn) error) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var u websocket.Upgrader + c, err := u.Upgrade(w, req, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer c.Close() + err = f(c) + var cm []byte + closeError, isCloseError := err.(*websocket.CloseError) + switch { + case err == nil: + cm = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + case isCloseError: + cm = websocket.FormatCloseMessage(closeError.Code, closeError.Text) + default: + cm = websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()) + } + _ = c.WriteControl(websocket.CloseMessage, cm, time.Time{}) + + }) +} + +func Echo(c *websocket.Conn) error { + for { + msg := make(map[string]interface{}) + if err := c.ReadJSON(&msg); err != nil { + return err + } + delete(msg, "type") + delete(msg, "version") + delete(msg, "id") + delete(msg, "request") + msg["response"] = msg["params"] + delete(msg, "params") + if err := c.WriteJSON(msg); err != nil { + return err + } + } +}