Skip to content

Commit

Permalink
Merge pull request #2 from Snawoot/iface_spec
Browse files Browse the repository at this point in the history
Iface spec
  • Loading branch information
Snawoot authored Jan 23, 2024
2 parents 6fdc801 + 3174c81 commit 5c40335
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Configuration example:

```yaml
listen:
- 239.82.71.65:8271
- 239.82.71.65:8271 # or "239.82.71.65:8271@eth0" or "239.82.71.65:8271@192.168.0.0/16"
- 127.0.0.1:8282

groups:
Expand Down
35 changes: 34 additions & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import (
"fmt"
"log"
"net"
"net/netip"
"strings"
"sync"
"time"

"github.com/Snawoot/rgap/config"
"github.com/Snawoot/rgap/protocol"
"github.com/Snawoot/rgap/util"
"github.com/hashicorp/go-multierror"
)

Expand Down Expand Up @@ -92,7 +95,12 @@ func (a *Agent) singleRun(ctx context.Context, t time.Time) error {
}

func (a *Agent) sendSingle(ctx context.Context, msg []byte, dst string) error {
conn, err := a.cfg.Dialer.DialContext(ctx, "udp", dst)
dstAddr, iface, err := util.SplitAndResolveAddrSpec(dst)
if err != nil {
return fmt.Errorf("destination %s: interface resolving failed: %w", dst, err)
}

conn, err := a.dialInterfaceContext(ctx, "udp", dstAddr, iface)
if err != nil {
return fmt.Errorf("Agent.sendSingle dial failed: %w", err)
}
Expand All @@ -111,3 +119,28 @@ func (a *Agent) sendSingle(ctx context.Context, msg []byte, dst string) error {
}
return nil
}

func (a *Agent) dialInterfaceContext(ctx context.Context, network, addr string, iif *net.Interface) (net.Conn, error) {
if iif == nil {
return a.cfg.Dialer.DialContext(ctx, network, addr)
}

var hints []string
addrs, err := iif.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
return nil, fmt.Errorf("unexpected type returned as address interface: %T", addr)
}
netipAddr, ok := netip.AddrFromSlice(ipnet.IP)
if !ok {
return nil, fmt.Errorf("interface %v has invalid address %s", iif.Name, ipnet.IP)
}
hints = append(hints, netipAddr.Unmap().String())
}
boundDialer := util.NewBoundDialer(a.cfg.Dialer, strings.Join(hints, ","))
return boundDialer.DialContext(ctx, network, addr)
}
10 changes: 8 additions & 2 deletions listener/udpsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"

"github.com/Snawoot/rgap/protocol"
"github.com/Snawoot/rgap/util"
)

type UDPSource struct {
Expand All @@ -33,15 +34,20 @@ func (s *UDPSource) Start() error {
s.ctxCancel = cancel
s.loopDone = make(chan struct{})

udpAddr, err := net.ResolveUDPAddr("udp", s.address)
listenAddr, iface, err := util.SplitAndResolveAddrSpec(s.address)
if err != nil {
return fmt.Errorf("UDP source %s: interface resolving failed: %w", s.address, err)
}

udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return fmt.Errorf("bad UDP listen address: %w", err)
}

var conn *net.UDPConn

if udpAddr.IP.IsMulticast() {
conn, err = net.ListenMulticastUDP("udp4", nil, udpAddr)
conn, err = net.ListenMulticastUDP("udp", iface, udpAddr)
if err != nil {
return fmt.Errorf("UDP listen failed: %w", err)
}
Expand Down
186 changes: 186 additions & 0 deletions util/hintdialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package util

import (
"context"
"errors"
"fmt"
"net"
"os"
"strings"

"github.com/hashicorp/go-multierror"
)

var (
ErrNoSuitableAddress = errors.New("no suitable address")
ErrBadIPAddressLength = errors.New("bad IP address length")
ErrUnknownNetwork = errors.New("unknown network")
)

type BoundDialerContextKey struct{}

type BoundDialerContextValue struct {
Hints *string
LocalAddr string
}

type BoundDialerDefaultSink interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

type BoundDialer struct {
defaultDialer BoundDialerDefaultSink
defaultHints string
}

func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer {
if defaultDialer == nil {
defaultDialer = &net.Dialer{}
}
return &BoundDialer{
defaultDialer: defaultDialer,
defaultHints: defaultHints,
}
}

func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
hints := d.defaultHints
lAddr := ""
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok {
if hintsOverrideValue.Hints != nil {
hints = *hintsOverrideValue.Hints
}
lAddr = hintsOverrideValue.LocalAddr
}
}

parsedHints, err := parseHints(hints, lAddr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}

if len(parsedHints) == 0 {
return d.defaultDialer.DialContext(ctx, network, address)
}

var netBase string
switch network {
case "tcp", "tcp4", "tcp6":
netBase = "tcp"
case "udp", "udp4", "udp6":
netBase = "udp"
case "ip", "ip4", "ip6":
netBase = "ip"
default:
return d.defaultDialer.DialContext(ctx, network, address)
}

var resErr error
for _, lIP := range parsedHints {
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
if err != nil {
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
continue
}
if network != netBase && network != restrictedNetwork {
continue
}

conn, err := (&net.Dialer{
LocalAddr: lAddr,
}).DialContext(ctx, restrictedNetwork, address)
if err != nil {
resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err))
} else {
return conn, nil
}
}

if resErr == nil {
resErr = ErrNoSuitableAddress
}
return nil, resErr
}

func (d *BoundDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}

func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
v6 := true
if ip4 := ip.To4(); len(ip4) == net.IPv4len {
ip = ip4
v6 = false
} else if len(ip) != net.IPv6len {
return nil, "", ErrBadIPAddressLength
}

var lAddr net.Addr
var lNetwork string
switch network {
case "tcp", "tcp4", "tcp6":
lAddr = &net.TCPAddr{
IP: ip,
}
if v6 {
lNetwork = "tcp6"
} else {
lNetwork = "tcp4"
}
case "udp", "udp4", "udp6":
lAddr = &net.UDPAddr{
IP: ip,
}
if v6 {
lNetwork = "udp6"
} else {
lNetwork = "udp4"
}
case "ip", "ip4", "ip6":
lAddr = &net.IPAddr{
IP: ip,
}
if v6 {
lNetwork = "ip6"
} else {
lNetwork = "ip4"
}
default:
return nil, "", ErrUnknownNetwork
}

return lAddr, lNetwork, nil
}

func parseHints(hints, lAddr string) ([]net.IP, error) {
hints = os.Expand(hints, func(key string) string {
switch key {
case "lAddr":
return lAddr
default:
return fmt.Sprintf("<bad key:%q>", key)
}
})
res, err := parseIPList(hints)
if err != nil {
return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err)
}
return res, nil
}

func parseIPList(list string) ([]net.IP, error) {
res := make([]net.IP, 0)
for _, elem := range strings.Split(list, ",") {
elem = strings.TrimSpace(elem)
if len(elem) == 0 {
continue
}
if parsed := net.ParseIP(elem); parsed == nil {
return nil, fmt.Errorf("unable to parse IP address %q", elem)
} else {
res = append(res, parsed)
}
}
return res, nil
}
59 changes: 59 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package util

import (
"bytes"
"errors"
"fmt"
"log"
"net"
"net/netip"
"strings"

"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -58,3 +62,58 @@ func CheckedUnmarshal(doc *yaml.Node, dst interface{}) error {
}
return nil
}

func SplitAndResolveAddrSpec(spec string) (string, *net.Interface, error) {
addrSpec, ifaceSpec, found := strings.Cut(spec, "@")
if !found {
return addrSpec, nil, nil
}
iface, err := ResolveInterface(ifaceSpec)
if err != nil {
return addrSpec, nil, fmt.Errorf("unable to resolve interface spec %q: %w", ifaceSpec, err)
}
return addrSpec, iface, nil
}

func ResolveInterface(spec string) (*net.Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("unable to enumerate interfaces: %w", err)
}
if pfx, err := netip.ParsePrefix(spec); err == nil {
// look for address
for i := range ifaces {
addrs, err := ifaces[i].Addrs()
if err != nil {
// may be a problem with some interface,
// but we still probably can find the right one
log.Printf("WARNING: interface %s is failing to report its addresses: %v", ifaces[i].Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
return nil, fmt.Errorf("unexpected type returned as address interface: %T", addr)
}
netipAddr, ok := netip.AddrFromSlice(ipnet.IP)
if !ok {
return nil, fmt.Errorf("interface %v has invalid address %s", ifaces[i].Name, ipnet.IP)
}
netipAddr = netipAddr.Unmap()
if pfx.Contains(netipAddr) {
res := ifaces[i]
return &res, nil
}
}
}
} else {
// look for iface name
for i := range ifaces {
if ifaces[i].Name == spec {
res := ifaces[i]
return &res, nil
}
}
}
return nil, errors.New("specified interface not found")
}

0 comments on commit 5c40335

Please sign in to comment.