diff --git a/go.mod b/go.mod index a56f08639..b725ae5ce 100644 --- a/go.mod +++ b/go.mod @@ -30,9 +30,11 @@ require ( github.com/coreos/go-oidc/v3 v3.11.0 github.com/dustinkirkland/golang-petname v0.0.0-20231002161417-6a283f1aaaf2 github.com/frankban/quicktest v1.14.6 + github.com/gliderlabs/ssh v0.3.8 github.com/go-chi/chi/v5 v5.0.12 github.com/go-chi/render v1.0.2 github.com/go-macaroon-bakery/macaroon-bakery/v3 v3.0.1 + github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/sessions v1.2.1 @@ -314,7 +316,7 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.31.0 // indirect + golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect golang.org/x/net v0.30.0 // indirect golang.org/x/sys v0.28.0 // indirect @@ -346,9 +348,8 @@ require ( sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect ) -require github.com/golang-migrate/migrate/v4 v4.17.1 - require ( + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.6.0 // indirect diff --git a/go.sum b/go.sum index 4b98f0c71..a340e473e 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/adrg/xdg v0.3.3 h1:s/tV7MdqQnzB1nKY8aqHvAMD+uCiuEDzVB5HLRY849U= github.com/adrg/xdg v0.3.3/go.mod h1:61xAR2VZcggl2St4O9ohF5qCKe08+JDmE4VNzPFQvOQ= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= @@ -253,6 +255,8 @@ github.com/gdamore/tcell/v2 v2.5.1/go.mod h1:wSkrPaXoiIWZqW/g7Px4xc79di6FTcpB8tv github.com/getkin/kin-openapi v0.125.0 h1:jyQCyf2qXS1qvs2U00xQzkGCqYPhEhZDmSmVt65fXno= github.com/getkin/kin-openapi v0.125.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg= diff --git a/internal/ssh/export_test.go b/internal/ssh/export_test.go new file mode 100644 index 000000000..b83b094cb --- /dev/null +++ b/internal/ssh/export_test.go @@ -0,0 +1,5 @@ +// Copyright 2025 Canonical. + +package ssh + +type ForwardMessage = forwardMessage diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go new file mode 100644 index 000000000..5a554be58 --- /dev/null +++ b/internal/ssh/ssh.go @@ -0,0 +1,144 @@ +// Copyright 2025 Canonical. + +package ssh + +import ( + "context" + "fmt" + "io" + "net" + + "github.com/gliderlabs/ssh" + "github.com/juju/zaputil/zapctx" + "go.uber.org/zap" + gossh "golang.org/x/crypto/ssh" + + "github.com/canonical/jimm/v3/internal/openfga" +) + +// juju_ssh_default_port is the default port we expect the juju controllers to respond on. +const juju_ssh_default_port = 17022 + +// Resolver is the interface with the methods needed by the ssh jump server to route request. +type Resolver interface { + // AddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. + AddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error) +} + +// fowardMessage is the struct holding the information about the jump message received by the ssh client. +type forwardMessage struct { + DestAddr string + DestPort uint32 + SrcAddr string + SrcPort uint32 +} + +// Server is the custom struct to embed the gliderlabs.ssh server and a resolver. +type Server struct { + *ssh.Server + + resolver Resolver +} + +// NewJumpSSHServer creates the jump server struct. +func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) { + zapctx.Info(ctx, "NewSSHServer") + + if resolver == nil { + return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.") + } + server := Server{ + Server: &ssh.Server{ + Addr: fmt.Sprintf(":%d", port), + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": directTCPIPHandler(resolver), + }, + PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { + return true + }, + }, + resolver: resolver, + } + + return server, nil +} + +func directTCPIPHandler(resolver Resolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + d := forwardMessage{} + + k := newChan.ExtraData() + + if err := gossh.Unmarshal(k, &d); err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err) + return + } + if d.DestPort == 0 { + d.DestPort = juju_ssh_default_port + } + addr, err := resolver.AddrFromModelUUID(ctx, openfga.User{}, d.DestAddr) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to resolve address from model uuid", err) + return + } + dest := net.JoinHostPort(addr, fmt.Sprint(d.DestPort)) + // this is temporary. The way we dial to the controller will heavily change. + client, err := gossh.Dial("tcp", dest, &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PasswordCallback(func() (secret string, err error) { + return "jwt", nil + }), + }, + }) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("Failed to connect to %s: %v", dest, err), err) + return + } + + dstChan, reqs, err := client.OpenChannel("direct-tcpip", gossh.Marshal(d)) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to open destination channel", err) + return + } + // gossh.Request are requests sent outside of the normal stream of data (ex. pty-req for an interactive session). + // Since we only need the raw data to redirect, we can discard them. + go gossh.DiscardRequests(reqs) + + srcDest, reqs, err := newChan.Accept() + if err != nil { + dstChan.Close() + return + } + // gossh.Request are requests sent outside of the normal stream of data (ex. pty-req for an interactive session). + // Since we only need the raw data to redirect, we can discard them. + go gossh.DiscardRequests(reqs) + + go func() { + defer srcDest.Close() + defer dstChan.Close() + _, err := io.Copy(srcDest, dstChan) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from src to dts", err) + } + }() + go func() { + defer srcDest.Close() + defer dstChan.Close() + _, err := io.Copy(dstChan, srcDest) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from dst to src", err) + } + }() + zapctx.Info(ctx, fmt.Sprintf("Proxying connection from %s:%d to %s:%d \n", d.SrcAddr, d.SrcPort, d.DestAddr, d.DestPort)) + } +} + +func rejectConnectionAndLogError(ctx context.Context, newChan gossh.NewChannel, msg string, err error) { + zapctx.Error(ctx, msg, zap.Error(err)) + err = newChan.Reject(gossh.ConnectionFailed, msg) + if err != nil { + zapctx.Error(ctx, msg, zap.Error(err)) + } +} diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go new file mode 100644 index 000000000..687afcbaa --- /dev/null +++ b/internal/ssh/ssh_test.go @@ -0,0 +1,171 @@ +// Copyright 2025 Canonical. + +package ssh_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "strconv" + "strings" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/frankban/quicktest/qtsuite" + gliderssh "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/internal/ssh" + "github.com/canonical/jimm/v3/internal/testutils/jimmtest" +) + +type resolver struct{} + +func (r resolver) AddrFromModelUUID(ctx context.Context, user openfga.User, modelName string) (string, error) { + return "", nil +} + +type sshSuite struct { + destinationJujuSSHServer gliderssh.Server + destinationServerPort int + jumpSSHServer ssh.Server + jumpServerPort int + privateKey gossh.Signer + testInDestinationServerF func(fm ssh.ForwardMessage) + received chan bool +} + +func (s *sshSuite) Init(c *qt.C) { + s.received = make(chan bool) + port, err := jimmtest.GetFreePort() + c.Assert(err, qt.IsNil) + s.destinationServerPort = port + s.destinationJujuSSHServer = gliderssh.Server{ + Addr: fmt.Sprintf(":%d", port), + ChannelHandlers: map[string]gliderssh.ChannelHandler{ + "direct-tcpip": func(srv *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) { + d := ssh.ForwardMessage{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + err := newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data") + c.Assert(err, qt.IsNil) + return + } + _, _, err := newChan.Accept() + c.Assert(err, qt.IsNil) + s.testInDestinationServerF(d) + s.received <- true + }, + }, + } + go func() { + _ = s.destinationJujuSSHServer.ListenAndServe() + }() + s.destinationServerPort, err = strconv.Atoi(strings.Split(s.destinationJujuSSHServer.Addr, ":")[1]) + c.Assert(err, qt.IsNil) + + port, err = jimmtest.GetFreePort() + c.Assert(err, qt.IsNil) + s.jumpServerPort = port + s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{}) + c.Assert(err, qt.IsNil) + go func() { + _ = s.jumpSSHServer.ListenAndServe() + }() + + k, err := rsa.GenerateKey(rand.Reader, 2048) + c.Assert(err, qt.IsNil) + keyPEM := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + }, + ) + + s.privateKey, err = gossh.ParsePrivateKey(keyPEM) + c.Assert(err, qt.IsNil) + c.Cleanup(func() { + err := s.destinationJujuSSHServer.Close() + c.Check(err, qt.IsNil) + err = s.jumpSSHServer.Close() + c.Check(err, qt.IsNil) + }) +} + +func (s *sshSuite) TestSSHJump(c *qt.C) { + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.IsNil) + defer client.Close() + + // send forward message + msg := ssh.ForwardMessage{ + DestAddr: "model1", + //nolint:gosec + DestPort: uint32(s.destinationServerPort), + SrcAddr: "localhost", + SrcPort: 0, + } + s.testInDestinationServerF = func(fm ssh.ForwardMessage) { + c.Check(fm.DestAddr, qt.Equals, "model1") + } + ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + c.Check(err, qt.IsNil) + defer ch.Close() + select { + case <-s.received: + case <-time.After(100 * time.Millisecond): + c.Fail() + } +} + +func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { + _, err := gossh.Dial("tcp", fmt.Sprintf(":%d", 1), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.ErrorMatches, ".*connect: connection refused.*") +} + +func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { + + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.IsNil) + + // send forward message + msg := ssh.ForwardMessage{ + DestAddr: "model1", + //nolint:gosec + DestPort: 1, // the test fails because there is no ssh server on this port. + SrcAddr: "localhost", + SrcPort: 0, + } + s.testInDestinationServerF = func(fm ssh.ForwardMessage) { + c.Check(fm.DestAddr, qt.Equals, "model1") + } + _, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + c.Assert(err, qt.ErrorMatches, ".*connect failed.*") + +} + +func TestIdentityManager(t *testing.T) { + qtsuite.Run(qt.New(t), &sshSuite{}) +} diff --git a/internal/testutils/jimmtest/utils.go b/internal/testutils/jimmtest/utils.go new file mode 100644 index 000000000..b2519f5d1 --- /dev/null +++ b/internal/testutils/jimmtest/utils.go @@ -0,0 +1,20 @@ +// Copyright 2025 Canonical. + +package jimmtest + +import ( + "errors" + "net" +) + +// GetFreePort asks the kernel for a free open port that is ready to use. +func GetFreePort() (int, error) { + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil + } + } + return 0, errors.New("Couldn't find any free port") +}