Skip to content

Commit

Permalink
feat: add keyManager handler to model proxy
Browse files Browse the repository at this point in the history
The model proxy will intercept calls to the keyManager facade and persist user keys in JIMM rather than passing these calls along to the Juju controller. This is being done in order to support the SSH proxy efforts.
Ideally, in Juju 4 these methods would be done on the controller api and all this logic would move into the jujuapi package.
  • Loading branch information
kian99 committed Jan 15, 2025
1 parent 217bd0f commit 44c5156
Show file tree
Hide file tree
Showing 8 changed files with 524 additions and 1 deletion.
6 changes: 6 additions & 0 deletions internal/jimm/jimm.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,12 @@ func (j *JIMM) ServiceAccountManager() ServiceAccountManager {
return j.serviceAccountManager
}

// SSHKeyManager returns a manager that enables operations
// related to ssh keys.
func (j *JIMM) SSHKeyManager() SSHKeyManager {
return j.sshKeyManager
}

type permission struct {
resource string
relation string
Expand Down
1 change: 1 addition & 0 deletions internal/jujuapi/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) {
AuditLog: auditLogger,
LoginService: s.jimm.LoginManager(),
AuthenticatedIdentityID: auth.SessionIdentityFromContext(ctx),
SSHKeyManager: s.jimm.SSHKeyManager(),
}
if err := rpcproxy.ProxySockets(ctx, proxyHelpers); err != nil {
zapctx.Error(ctx, "failed to start jimm model proxy", zap.Error(err))
Expand Down
5 changes: 4 additions & 1 deletion internal/rpcproxy/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@

package rpcproxy

type Message = message
type (
Message = message
KeyManagerFacade = keyManagerFacade
)
97 changes: 97 additions & 0 deletions internal/rpcproxy/rpcproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/errors"
"github.com/canonical/jimm/v3/internal/jimm/sshkeys"
"github.com/canonical/jimm/v3/internal/openfga"
"github.com/canonical/jimm/v3/internal/servermon"
"github.com/canonical/jimm/v3/internal/utils"
Expand All @@ -32,6 +33,18 @@ const (
accessRequiredErrorCode = "access required"
)

// SSHKeyManager is an interface for managing SSH keys.
type SSHKeyManager interface {
// AddUserPublicKey saves a user's public key.
AddUserPublicKey(ctx context.Context, user *openfga.User, publicKey sshkeys.PublicKey) error
// ListUserPublicKeys lists a user's public keys.
ListUserPublicKeys(ctx context.Context, user *openfga.User) ([]sshkeys.PublicKey, error)
// RemoveUserKeyByComment removes a user's public key(s) by the key comment.
RemoveUserKeyByComment(ctx context.Context, user *openfga.User, comment string) error
// RemoveUserKeyByFingerprint removes a user's public key(s) by the key fingerprint.
RemoveUserKeyByFingerprint(ctx context.Context, user *openfga.User, fingerprint string) error
}

// TokenGenerator authenticates a user and generates a JWT token.
type TokenGenerator interface {
// MakeLoginToken returns a JWT containing claims about user's access
Expand Down Expand Up @@ -77,6 +90,7 @@ type LoginService interface {
// connection to a model.
type ProxyHelpers struct {
ConnClient WebsocketConnection
SSHKeyManager SSHKeyManager
TokenGen TokenGenerator
ConnectController func(context.Context) (WebsocketConnectionWithMetadata, error)
AuditLog func(*dbmodel.AuditLogEntry)
Expand All @@ -101,6 +115,10 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error {
zapctx.Error(ctx, "Missing login service function")
return errors.E(op, "Missing login service function")
}
if helpers.SSHKeyManager == nil {
zapctx.Error(ctx, "Missing ssh key manager function")
return errors.E(op, "Missing ssh key manager function")
}
errChan := make(chan error, 2)
msgInFlight := inflightMsgs{messages: make(map[uint64]*message)}
client := writeLockConn{conn: helpers.ConnClient}
Expand All @@ -113,6 +131,7 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error {
tokenGen: helpers.TokenGen,
auditLog: helpers.AuditLog,
conversationId: utils.NewConversationID(),
sshKeyManager: helpers.SSHKeyManager,
loginService: helpers.LoginService,
authenticatedIdentityID: helpers.AuthenticatedIdentityID,
},
Expand Down Expand Up @@ -247,6 +266,7 @@ type modelProxy struct {
msgs *inflightMsgs
auditLog func(*dbmodel.AuditLogEntry)
tokenGen TokenGenerator
sshKeyManager SSHKeyManager
loginService LoginService
modelName string
conversationId string
Expand Down Expand Up @@ -331,6 +351,7 @@ func unexpectedReadError(err error) bool {
// clientProxy proxies messages from client->controller.
type clientProxy struct {
modelProxy
user *openfga.User
wg sync.WaitGroup
errChan chan error
createControllerConn func(context.Context) (WebsocketConnectionWithMetadata, error)
Expand Down Expand Up @@ -384,6 +405,19 @@ func (p *clientProxy) start(ctx context.Context) error {
p.msgs.addLoginMessage(toController)
}
}
// This is a special case for the KeyManager facade. We handle it here
// because it is a model level facade. In Juju 4 we want to move this
// to a controller level facade and place the logic in jujuapi.
if msg.Type == "KeyManager" {
zapctx.Debug(ctx, "handling a KeyManager facade call")
toClient, err := p.handleKeyManagerFacade(ctx, msg)
if err != nil {
p.sendError(p.src, msg, err)
continue
}
p.src.sendMessage(nil, toClient)
continue
}
p.msgs.addMessage(msg)
zapctx.Debug(ctx, "Writing to controller")
if err := p.dst.writeJson(msg); err != nil {
Expand Down Expand Up @@ -683,6 +717,7 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie
if err != nil {
return errorFnc(err)
}
p.user = user

return controllerLoginMessageFnc(user)
case "LoginWithClientCredentials":
Expand All @@ -695,13 +730,15 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie
if err != nil {
return errorFnc(err)
}
p.user = user

return controllerLoginMessageFnc(user)
case "LoginWithSessionCookie":
user, err := p.loginService.LoginWithSessionCookie(ctx, p.modelProxy.authenticatedIdentityID)
if err != nil {
return errorFnc(err)
}
p.user = user

return controllerLoginMessageFnc(user)
case "Login":
Expand All @@ -710,3 +747,63 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie
return nil, nil, nil
}
}

// handleKeyManagerFacade processes the key manager facade call.
func (p *clientProxy) handleKeyManagerFacade(ctx context.Context, msg *message) (clientResponse *message, err error) {
if p.user == nil {
return nil, errors.E("user not authenticated")
}
clientRespF := func(data any) (*message, error) {
resp, err := json.Marshal(data)
if err != nil {
return nil, err
}
msg.Response = resp
return msg, nil
}
keyManager := keyManagerFacade{SSHKeyManager: p.sshKeyManager, user: p.user}

switch msg.Request {
case "ListKeys":
var request params.ListSSHKeys
err := json.Unmarshal(msg.Params, &request)
if err != nil {
return nil, err
}
res, err := keyManager.ListKeys(ctx, request)
if err != nil {
return nil, err
}
return clientRespF(res)

case "AddKeys":
var request params.ModifyUserSSHKeys
err := json.Unmarshal(msg.Params, &request)
if err != nil {
return nil, err
}
res, err := keyManager.AddKeys(ctx, request)
if err != nil {
return nil, err
}
return clientRespF(res)

case "DeleteKeys":
var request params.ModifyUserSSHKeys
err := json.Unmarshal(msg.Params, &request)
if err != nil {
return nil, err
}
res, err := keyManager.DeleteKeys(ctx, request)
if err != nil {
return nil, err
}
return clientRespF(res)

case "ImportKeys":
return nil, errors.E("ImportKeys not implemented", errors.CodeNotImplemented)

default:
return nil, errors.E("unknown key manager request")
}
}
144 changes: 144 additions & 0 deletions internal/rpcproxy/rpcproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import (

qt "github.com/frankban/quicktest"
"github.com/gorilla/websocket"
jujuparams "github.com/juju/juju/rpc/params"
"github.com/juju/names/v5"
"github.com/juju/utils/v3/ssh"

"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/jimm/sshkeys"
"github.com/canonical/jimm/v3/internal/openfga"
"github.com/canonical/jimm/v3/internal/rpcproxy"
"github.com/canonical/jimm/v3/internal/testutils/jimmtest/mocks"
"github.com/canonical/jimm/v3/internal/testutils/rpctest"
)

Expand Down Expand Up @@ -58,6 +62,7 @@ func TestProxySockets(t *testing.T) {
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
SSHKeyManager: &mocks.SSHKeyManager{},
}
err := rpcproxy.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
Expand Down Expand Up @@ -117,6 +122,7 @@ func TestProxySocketsControllerConnectionFails(t *testing.T) {
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
SSHKeyManager: &mocks.SSHKeyManager{},
}
err := rpcproxy.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
Expand Down Expand Up @@ -176,6 +182,7 @@ func TestCancelProxySockets(t *testing.T) {
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
SSHKeyManager: &mocks.SSHKeyManager{},
}
err := rpcproxy.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.ErrorMatches, "Context cancelled")
Expand Down Expand Up @@ -217,6 +224,7 @@ func TestProxySocketsAuditLogs(t *testing.T) {
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
SSHKeyManager: &mocks.SSHKeyManager{},
}
err := rpcproxy.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
Expand Down Expand Up @@ -272,3 +280,139 @@ func TestProxySocketsAuditLogs(t *testing.T) {
c.Assert(auditLogs, qt.DeepEquals, expectedEvents)

}

func TestProxySocketsSSHKeys(t *testing.T) {
c := qt.New(t)

ctx := context.Background()
sshFacadeChan := make(chan (string), 1)

srvController := rpctest.NewServer(rpctest.Echo)

errChan := make(chan error)
srvJIMM := rpctest.NewServer(func(connClient *websocket.Conn) error {
defer connClient.Close()
testTokenGen := testTokenGenerator{}
connectControllerF := func(context.Context) (rpcproxy.WebsocketConnectionWithMetadata, error) {
connController := srvController.Dialer.DialWebsocket(c, srvController.URL)
return rpcproxy.WebsocketConnectionWithMetadata{
Conn: connController,
ModelName: "TestModelName",
}, nil
}
proxyHelpers := rpcproxy.ProxyHelpers{
ConnClient: connClient,
TokenGen: &testTokenGen,
ConnectController: connectControllerF,
AuditLog: func(ale *dbmodel.AuditLogEntry) {},
LoginService: &mockLoginService{
email: "alice@canonical.com",
},
SSHKeyManager: &mocks.SSHKeyManager{
AddUserPublicKey_: func(ctx context.Context, user *openfga.User, publicKey sshkeys.PublicKey) error {
sshFacadeChan <- "add-keys"
return nil
},
ListUserPublicKeys_: func(ctx context.Context, user *openfga.User) ([]sshkeys.PublicKey, error) {
sshFacadeChan <- "list-keys"
return nil, nil
},
RemoveUserKeyByComment_: func(ctx context.Context, user *openfga.User, comment string) error {
sshFacadeChan <- "remove-keys-comment"
return nil
},
RemoveUserKeyByFingerprint_: func(ctx context.Context, user *openfga.User, fingerprint string) error {
sshFacadeChan <- "remove-keys-fingerprint"
return nil
},
},
}
err := rpcproxy.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
errChan <- err
return err
})

defer srvController.Close()
defer srvJIMM.Close()
ws := srvJIMM.Dialer.DialWebsocket(c, srvJIMM.URL)
defer ws.Close()

// Perform login
p := json.RawMessage(`{"Key":"TestVal"}`)
msg := rpcproxy.Message{RequestID: 1, Type: "Admin", Request: "LoginWithSessionToken", Params: p} // #nosec G115 accept integer conversion
err := ws.WriteJSON(&msg)
c.Assert(err, qt.IsNil)
resp := rpcproxy.Message{}
err = ws.ReadJSON(&resp)
c.Assert(err, qt.IsNil)
c.Assert(resp.Error, qt.Equals, "")

// Run sub-tests for all SSH Key methods
tests := []struct {
name string
request string
params []byte
expectedChanResult string
expectedErr string
}{
{
name: "Add key method",
request: "AddKeys",
expectedChanResult: "add-keys",
params: mustMarshal(jujuparams.ModifyUserSSHKeys{Keys: []string{"type key comment"}}),
},
{
name: "List keys method",
request: "ListKeys",
expectedChanResult: "list-keys",
params: mustMarshal(jujuparams.ListSSHKeys{Mode: ssh.Fingerprints}),
},
{
name: "Delete keys by comment",
request: "DeleteKeys",
expectedChanResult: "remove-keys-comment",
params: mustMarshal(jujuparams.ModifyUserSSHKeys{Keys: []string{"comment"}}),
},
{
name: "Delete keys by fingerprint",
request: "DeleteKeys",
expectedChanResult: "remove-keys-fingerprint",
params: mustMarshal(jujuparams.ModifyUserSSHKeys{Keys: []string{"79:fc:60:93:ec:ce:42:fe:15:61:f2:fb:d6:22:43:6e"}}),
},
}

for i, test := range tests {
c.Run(test.name, func(c *qt.C) {
msg := rpcproxy.Message{RequestID: uint64(i + 1), Type: "KeyManager", Request: test.request, Params: test.params} // #nosec G115 accept integer conversion
err := ws.WriteJSON(&msg)
c.Assert(err, qt.IsNil)

resp := rpcproxy.Message{}
err = ws.ReadJSON(&resp)
c.Assert(err, qt.IsNil)
if test.expectedErr == "" {
c.Assert(resp.Error, qt.Equals, "")
} else {
c.Assert(err, qt.Matches, test.expectedErr)
}

select {
case res := <-sshFacadeChan:
c.Assert(res, qt.Equals, test.expectedChanResult)
case <-time.After(100 * time.Millisecond):
c.Error("Expected SSH method was not called")
}
})
}
ws.Close()
<-errChan // Ensure go routines are cleaned up
}

func mustMarshal(data any) []byte {
out, err := json.Marshal(data)
if err != nil {
panic(err)
}
return out
}
Loading

0 comments on commit 44c5156

Please sign in to comment.