Skip to content

Commit

Permalink
Fix DNS resolution for routes on iOS (#2378)
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored Aug 2, 2024
1 parent 727a4f0 commit 501fd93
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 37 deletions.
6 changes: 3 additions & 3 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func NewDefaultServer(

var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
dnsService = NewServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
Expand All @@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
Expand All @@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager
return ds
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
service: newServiceViaMemory(&mocWGIface{}),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
Expand Down
20 changes: 10 additions & 10 deletions client/internal/dns/service_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
)

type serviceViaMemory struct {
type ServiceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
Expand All @@ -22,8 +22,8 @@ type serviceViaMemory struct {
listenerFlagLock sync.Mutex
}

func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),

Expand All @@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
return s
}

func (s *serviceViaMemory) Listen() error {
func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()

Expand All @@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
return nil
}

func (s *serviceViaMemory) Stop() {
func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()

Expand All @@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
s.listenerIsRunning = false
}

func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}

func (s *serviceViaMemory) DeregisterMux(pattern string) {
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

func (s *serviceViaMemory) RuntimePort() int {
func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort
}

func (s *serviceViaMemory) RuntimeIP() string {
func (s *ServiceViaMemory) RuntimeIP() string {
return s.runtimeIP
}

func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
Expand Down
38 changes: 21 additions & 17 deletions client/internal/dns/upstream_ios.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package dns

import (
"context"
"fmt"
"net"
"syscall"
"time"
Expand All @@ -17,9 +18,9 @@ import (

type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
iIndex int
lIP net.IP
lNet *net.IPNet
interfaceName string
}

func newUpstreamResolver(
Expand All @@ -32,17 +33,11 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)

index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}

ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
iIndex: index,
interfaceName: interfaceName,
}
ios.upstreamClient = ios

Expand All @@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
}

timeout := upstreamTimeout
Expand All @@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate(timeout)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
}

// Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
}

// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}

dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
IP: ip,
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
}

if err := c.Control(fn); err != nil {
Expand All @@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
client := &dns.Client{
Dialer: dialer,
}
return client
return client, nil
}

func getInterfaceIndex(interfaceName string) (int, error) {
Expand Down
8 changes: 5 additions & 3 deletions client/internal/routemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"

nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
Expand Down Expand Up @@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
}
return client
}
Expand Down Expand Up @@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}

func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
19 changes: 16 additions & 3 deletions client/internal/routemanager/dynamic/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
Expand Down Expand Up @@ -47,6 +48,8 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
resolverAddr string
}

func NewRoute(
Expand All @@ -55,6 +58,8 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface *iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
route: rt,
Expand All @@ -63,6 +68,8 @@ func NewRoute(
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
}
}

Expand Down Expand Up @@ -228,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))

ips, err := r.getIPsFromResolver(domain)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
ips, err = net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
}

for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions client/internal/routemanager/dynamic/route_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build !ios

package dynamic

import (
"net"

"github.com/netbirdio/netbird/management/domain"
)

func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain))
}
55 changes: 55 additions & 0 deletions client/internal/routemanager/dynamic/route_ios.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//go:build ios

package dynamic

import (
"fmt"
"net"
"time"

"github.com/miekg/dns"

nbdns "github.com/netbirdio/netbird/client/internal/dns"

"github.com/netbirdio/netbird/management/domain"
)

const dialTimeout = 10 * time.Second

func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err)
}

msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)

startTime := time.Now()

response, _, err := privateClient.Exchange(msg, r.resolverAddr)
if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}

if response.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
}

ips := make([]net.IP, 0)

for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
}

if len(ips) == 0 {
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
}

return ips, nil
}

0 comments on commit 501fd93

Please sign in to comment.