Skip to content

Commit

Permalink
Fix websocket dial to ipv6 (#1454)
Browse files Browse the repository at this point in the history
* fix ipv6 dial

* fly by pass the context to dial
  • Loading branch information
SimoneDutto authored Nov 25, 2024
1 parent cb701b8 commit 03071b0
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 2 deletions.
22 changes: 22 additions & 0 deletions internal/rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -507,6 +508,27 @@ func newServer(f func(*websocket.Conn) error) *server {
return &srv
}

func newIPv6Server(f func(*websocket.Conn) error) *server {
var srv server
l, _ := net.Listen("tcp", "[::1]:0")
server := httptest.Server{
Listener: l,
Config: &http.Server{Handler: handleWS(f)}, //nolint:gosec
}
server.StartTLS()
srv.Server = &server
srv.URL = "ws" + strings.TrimPrefix(srv.Server.URL, "http")
cp := x509.NewCertPool()
cp.AddCert(srv.Certificate())
srv.dialer = &rpc.Dialer{
TLSConfig: &tls.Config{
RootCAs: cp,
MinVersion: tls.VersionTLS12,
},
}
return &srv
}

func handleWS(f func(*websocket.Conn) error) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var u websocket.Upgrader
Expand Down
66 changes: 66 additions & 0 deletions internal/rpc/dial _test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 Canonical.

package rpc_test

import (
"context"
"encoding/pem"
"net/http"
"testing"

qt "github.com/frankban/quicktest"
"github.com/juju/juju/core/network"
jujuparams "github.com/juju/juju/rpc/params"
"github.com/juju/names/v5"

"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/rpc"
)

func TestDialIPv4(t *testing.T) {
c := qt.New(t)
ctx := context.Background()
fakeController := newServer(echo)
defer fakeController.Close()
controller := dbmodel.Controller{}
pemData := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: fakeController.Certificate().Raw,
})
controller.CACertificate = string(pemData)
hp, err := network.ParseMachineHostPort(fakeController.Listener.Addr().String())
c.Assert(err, qt.Equals, nil)
controller.Addresses = append(make([][]jujuparams.HostPort, 0), []jujuparams.HostPort{{
Address: jujuparams.Address{
Value: hp.Value,
Type: "ipv4",
},
Port: hp.Port(),
}})
_, err = rpc.Dial(ctx, &controller, names.ModelTag{}, "", http.Header{})
c.Assert(err, qt.Equals, nil)
}

func TestDialIPv6(t *testing.T) {
c := qt.New(t)
ctx := context.Background()
fakeController := newIPv6Server(echo)
defer fakeController.Close()
controller := dbmodel.Controller{}
pemData := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: fakeController.Certificate().Raw,
})
controller.CACertificate = string(pemData)
hp, err := network.ParseMachineHostPort(fakeController.Listener.Addr().String())
c.Assert(err, qt.Equals, nil)
controller.Addresses = append(make([][]jujuparams.HostPort, 0), []jujuparams.HostPort{{
Address: jujuparams.Address{
Value: hp.Value,
Type: "ipv6",
},
Port: hp.Port(),
}})
_, err = rpc.Dial(ctx, &controller, names.ModelTag{}, "", http.Header{})
c.Assert(err, qt.Equals, nil)
}
10 changes: 8 additions & 2 deletions internal/rpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (d Dialer) DialWebsocket(ctx context.Context, url string, headers http.Head
dialer := websocket.Dialer{
TLSClientConfig: d.TLSConfig,
}
conn, resp, err := dialer.DialContext(context.Background(), url, headers)
conn, resp, err := dialer.DialContext(ctx, url, headers)
if err != nil {
zapctx.Error(ctx, "BasicDial failed", zap.Error(err))
return nil, errors.E(op, err)
Expand Down Expand Up @@ -89,7 +89,13 @@ func Dial(ctx context.Context, ctl *dbmodel.Controller, modelTag names.ModelTag,
for _, hps := range ctl.Addresses {
for _, hp := range hps {
if maybeReachable(hp.Scope) {
urls = append(urls, websocketURL(fmt.Sprintf("%s:%d", hp.Value, hp.Port), modelTag, finalPath))
var ip string
if hp.Type == string(network.IPv6Address) {
ip = fmt.Sprintf("[%s]:%d", hp.Value, hp.Port)
} else {
ip = fmt.Sprintf("%s:%d", hp.Value, hp.Port)
}
urls = append(urls, websocketURL(ip, modelTag, finalPath))
}
}
}
Expand Down

0 comments on commit 03071b0

Please sign in to comment.