diff --git a/dns/config_linux.go b/dns/config_linux.go index e22f86e0..83ed1845 100644 --- a/dns/config_linux.go +++ b/dns/config_linux.go @@ -7,12 +7,15 @@ import ( "os" "slices" "strings" + "sync" "github.com/gravitl/netclient/config" "github.com/gravitl/netclient/ncutils" "golang.org/x/exp/slog" ) +var dnsConfigMutex = sync.Mutex{} // used to mutex functions of the DNS + const ( resolvconfFilePath = "/etc/resolv.conf" resolvconfFileBkpPath = "/etc/netclient/resolv.conf.nm.bkp" @@ -31,7 +34,8 @@ func isStubSupported() bool { // } func SetupDNSConfig() (err error) { - + dnsConfigMutex.Lock() + defer dnsConfigMutex.Unlock() if isStubSupported() { err = setupResolvectl() } else { @@ -42,7 +46,8 @@ func SetupDNSConfig() (err error) { } func RestoreDNSConfig() (err error) { - + dnsConfigMutex.Lock() + defer dnsConfigMutex.Unlock() if isStubSupported() { } else { @@ -99,9 +104,8 @@ func getIpFromServerString(addrStr string) string { return s } -func setupResolveconf() error { +func backupResolveconfFile() error { - // backup /etc/resolv.conf _, err := os.Stat(resolvconfFileBkpPath) if err != nil { src_file, err := os.Open(resolvconfFilePath) @@ -123,39 +127,29 @@ func setupResolveconf() error { return err } } + return nil +} - //get nameserver and search domain - dnsIp := GetDNSServerInstance().AddrStr - if dnsIp == "" { - return errors.New("no listener is running") - } - if len(config.GetNodes()) == 0 { - return errors.New("no network joint") - } +func buildAddConfigContent() ([]string, error) { - domains := "search" - for _, v := range config.GetNodes() { - domains = domains + " " + v.Network + //get nameserver and search domain + ns, domains, err := getNSAndDomains() + if err != nil { + slog.Error("error in getting getNSAndDomains", "error", err.Error()) + return []string{}, err } - domains = domains + " " + config.Netclient().DNSSearch - - dnsIp = getIpFromServerString(dnsIp) - - ns := "nameserver" - ns = ns + " " + dnsIp - // add nameserver and search domain f, err := os.Open(resolvconfFilePath) if err != nil { slog.Error("error opending file", "error", resolvconfFilePath, err.Error()) - return err + return []string{}, err } defer f.Close() rawBytes, err := io.ReadAll(f) if err != nil { slog.Error("error reading file", "error", resolvconfFilePath, err.Error()) - return err + return []string{}, err } lines := strings.Split(string(rawBytes), "\n") lNo := 0 @@ -169,7 +163,26 @@ func setupResolveconf() error { lines = slices.Insert(lines, lNo, ns) lines = slices.Insert(lines, lNo, domains) - f, err = os.OpenFile(resolvconfFilePath, os.O_CREATE|os.O_WRONLY, 0700) + return lines, nil +} + +func setupResolveconf() error { + + // backup /etc/resolv.conf + err := backupResolveconfFile() + if err != nil { + slog.Error("could not backup ", resolvconfFilePath, "error", err.Error()) + return err + } + + // add nameserver and search domain + lines, err := buildAddConfigContent() + if err != nil { + slog.Error("could not build config content", "error", err.Error()) + return err + } + + f, err := os.OpenFile(resolvconfFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) if err != nil { slog.Error("error opending file", "error", resolvconfFilePath, err.Error()) return err @@ -187,33 +200,75 @@ func setupResolveconf() error { return nil } -func restoreResolveconf() error { +func getNSAndDomains() (string, string, error) { + + dnsIp := GetDNSServerInstance().AddrStr + if dnsIp == "" { + return "", "", errors.New("no listener is running") + } + if len(config.GetNodes()) == 0 { + return "", "", errors.New("no network joint") + } + + domains := "search" + for _, v := range config.GetNodes() { + domains = domains + " " + v.Network + } + domains = domains + " " + config.Netclient().DNSSearch + + dnsIp = getIpFromServerString(dnsIp) + + ns := "nameserver" + ns = ns + " " + dnsIp + + return ns, domains, nil +} +func buildDeleteConfigContent() ([]string, error) { f, err := os.Open(resolvconfFilePath) if err != nil { slog.Error("error opending file", "error", resolvconfFilePath, err.Error()) - return err + return []string{}, err } defer f.Close() rawBytes, err := io.ReadAll(f) if err != nil { slog.Error("error reading file", "error", resolvconfFilePath, err.Error()) - return err + return []string{}, err } lines := strings.Split(string(rawBytes), "\n") + + //get nameserver and search domain + _, domains, err := getNSAndDomains() + if err != nil { + slog.Error("error in getting getNSAndDomains", "error", err.Error()) + return []string{}, err + } + lNo := 0 for i, line := range lines { - if strings.HasPrefix(line, "nameserver") { + if strings.Contains(line, domains) { lNo = i break } } - //Delete the added search and nameserver two lines - lines = slices.Delete(lines, lNo-1, lNo) + lines = slices.Delete(lines, lNo, lNo+1) + lines = slices.Delete(lines, lNo, lNo+1) + + return lines, nil +} + +func restoreResolveconf() error { + + lines, err := buildDeleteConfigContent() + if err != nil { + slog.Error("could not build config content", "error", err.Error()) + return err + } - f, err = os.OpenFile(resolvconfFilePath, os.O_CREATE|os.O_WRONLY, 0700) + f, err := os.OpenFile(resolvconfFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) if err != nil { slog.Error("error opending file", "error", resolvconfFilePath, err.Error()) return err @@ -239,7 +294,8 @@ func restoreResolveconf() error { } func InitDNSConfig() { - + dnsConfigMutex.Lock() + defer dnsConfigMutex.Unlock() f, err := os.Open(resolvconfFilePath) if err != nil { slog.Error("error opending file", "error", resolvconfFilePath, err.Error()) diff --git a/dns/listener.go b/dns/listener.go index 9b6e6f7e..93fcdd6c 100644 --- a/dns/listener.go +++ b/dns/listener.go @@ -97,6 +97,14 @@ func (dnsServer *DNSServer) Stop() { return } + //restore DNS config for Linux + if config.Netclient().Host.OS == "linux" { + err := RestoreDNSConfig() + if err != nil { + slog.Error("Restore DNS conig failed", "error", err.Error()) + } + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -107,12 +115,4 @@ func (dnsServer *DNSServer) Stop() { dnsServer.AddrStr = "" dnsServer.DnsServer = nil - - //restore DNS config for Linux - if config.Netclient().Host.OS == "linux" { - err := RestoreDNSConfig() - if err != nil { - slog.Error("Restore DNS conig failed", "error", err.Error()) - } - } } diff --git a/dns/resolver.go b/dns/resolver.go index 5d69a5b6..cf007d94 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -56,8 +56,10 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { continue } - reply.Answer = append(reply.Answer, resp.Answer[0]) - break + if len(resp.Answer) > 0 { + reply.Answer = append(reply.Answer, resp.Answer[0]) + break + } } } else { reply.Rcode = dns.RcodeNameError diff --git a/functions/mqhandlers.go b/functions/mqhandlers.go index e740d5ad..a3fc45b8 100644 --- a/functions/mqhandlers.go +++ b/functions/mqhandlers.go @@ -398,7 +398,7 @@ func resetInterfaceFunc() { wireguard.SetRoutesFromCache() //Setup resolveconf for Linux - if config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" { + if config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && (config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB || config.Netclient().DNSManagerType == dns.DNS_MANAGER_UPLINK) { dns.SetupDNSConfig() } }