diff --git a/src/modules/mxlookup/mx_lookup.go b/src/modules/mxlookup/mx_lookup.go index d89f0933..1d853cc2 100644 --- a/src/modules/mxlookup/mx_lookup.go +++ b/src/modules/mxlookup/mx_lookup.go @@ -15,14 +15,12 @@ package mxlookup import ( "strings" - "sync" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/zmap/dns" "github.com/zmap/zdns/src/cli" - "github.com/zmap/zdns/src/internal/cachehash" "github.com/zmap/zdns/src/zdns" ) @@ -51,11 +49,8 @@ type MXResult struct { } type MXLookupModule struct { - IPv4Lookup bool `long:"ipv4-lookup" description:"perform A lookups for each MX server"` - IPv6Lookup bool `long:"ipv6-lookup" description:"perform AAAA record lookups for each MX server"` - MXCacheSize int `long:"mx-cache-size" default:"1000" description:"number of records to store in MX -> A/AAAA cache"` - CacheHash *cachehash.CacheHash - CHmu sync.Mutex + IPv4Lookup bool `long:"ipv4-lookup" description:"perform A lookups for each MX server"` + IPv6Lookup bool `long:"ipv6-lookup" description:"perform AAAA record lookups for each MX server"` cli.BasicLookupModule } @@ -77,31 +72,15 @@ func (mxMod *MXLookupModule) Init() { if !mxMod.IPv4Lookup && !mxMod.IPv6Lookup { log.Fatal("At least one of ipv4-lookup or ipv6-lookup must be true") } - if mxMod.MXCacheSize <= 0 { - log.Fatal("mxCacheSize must be greater than 0, got ", mxMod.MXCacheSize) - } - mxMod.CacheHash = new(cachehash.CacheHash) - mxMod.CacheHash.Init(mxMod.MXCacheSize) } func (mxMod *MXLookupModule) lookupIPs(r *zdns.Resolver, name string, nameServer *zdns.NameServer, ipMode zdns.IPVersionMode) (CachedAddresses, zdns.Trace) { - mxMod.CHmu.Lock() - // TODO - Phillip this comment V is present in the original code and has been there since 2017 IIRC, so ask Zakir what to do - // XXX this should be changed to a miekglookup - res, found := mxMod.CacheHash.Get(name) - mxMod.CHmu.Unlock() - if found { - return res.(CachedAddresses), zdns.Trace{} - } retv := CachedAddresses{} result, trace, status, _ := r.DoTargetedLookup(name, nameServer, mxMod.IsIterative, mxMod.IPv4Lookup, mxMod.IPv6Lookup) if status == zdns.StatusNoError && result != nil { retv.IPv4Addresses = result.IPv4Addresses retv.IPv6Addresses = result.IPv6Addresses } - mxMod.CHmu.Lock() - mxMod.CacheHash.Upsert(name, retv) - mxMod.CHmu.Unlock() return retv, trace } diff --git a/src/zdns/cache.go b/src/zdns/cache.go index f5ff6daa..b2b42c3a 100644 --- a/src/zdns/cache.go +++ b/src/zdns/cache.go @@ -30,6 +30,11 @@ type TimedAnswer struct { ExpiresAt time.Time } +type CachedKey struct { + Question Question + NameServer string // optional +} + type CachedResult struct { Answers map[interface{}]TimedAnswer } @@ -46,7 +51,7 @@ func (s *Cache) VerboseLog(depth int, args ...interface{}) { log.Debug(makeVerbosePrefix(depth), args) } -func (s *Cache) AddCachedAnswer(answer interface{}, depth int) { +func (s *Cache) AddCachedAnswer(answer interface{}, ns *NameServer, depth int) { a, ok := answer.(Answer) if !ok { // we can't cache this entry because we have no idea what to name it @@ -64,13 +69,17 @@ func (s *Cache) AddCachedAnswer(answer interface{}, depth int) { return } expiresAt := time.Now().Add(time.Duration(a.TTL) * time.Second) - s.IterativeCache.Lock(q) - defer s.IterativeCache.Unlock(q) - // don't bother to move this to the top of the linked list. we're going - // to add this record back in momentarily and that will take care of this ca := CachedResult{} ca.Answers = make(map[interface{}]TimedAnswer) - i, ok := s.IterativeCache.GetNoMove(q) + cacheKey := CachedKey{q, ""} + if ns != nil { + cacheKey.NameServer = ns.String() + } + s.IterativeCache.Lock(cacheKey) + defer s.IterativeCache.Unlock(cacheKey) + // don't bother to move this to the top of the linked list. we're going + // to add this record back in momentarily and that will take care of this + i, ok := s.IterativeCache.GetNoMove(cacheKey) if ok { // record found, check type on interface ca, ok = i.(CachedResult) @@ -83,18 +92,24 @@ func (s *Cache) AddCachedAnswer(answer interface{}, depth int) { Answer: answer, ExpiresAt: expiresAt} ca.Answers[a] = ta - s.IterativeCache.Add(q, ca) + s.IterativeCache.Add(cacheKey, ca) s.VerboseLog(depth+1, "Upsert cached answer ", q, " ", ca) } -func (s *Cache) GetCachedResult(q Question, isAuthCheck bool, depth int) (SingleQueryResult, bool) { - s.VerboseLog(depth+1, "Cache request for: ", q.Name, " (", q.Type, ")") +func (s *Cache) GetCachedResult(q Question, ns *NameServer, depth int) (SingleQueryResult, bool) { var retv SingleQueryResult - s.IterativeCache.Lock(q) - unres, ok := s.IterativeCache.Get(q) + cacheKey := CachedKey{q, ""} + if ns != nil { + cacheKey.NameServer = ns.String() + s.VerboseLog(depth+1, "Cache request for: ", q.Name, " (", q.Type, ") @", cacheKey.NameServer) + } else { + s.VerboseLog(depth+1, "Cache request for: ", q.Name, " (", q.Type, ")") + } + s.IterativeCache.Lock(cacheKey) + unres, ok := s.IterativeCache.Get(cacheKey) if !ok { // nothing found s.VerboseLog(depth+2, "-> no entry found in cache") - s.IterativeCache.Unlock(q) + s.IterativeCache.Unlock(cacheKey) return retv, false } retv.Authorities = make([]interface{}, 0) @@ -116,26 +131,25 @@ func (s *Cache) GetCachedResult(q Question, isAuthCheck bool, depth int) (Single delete(cachedRes.Answers, k) } else { // this result is valid. append it to the SingleQueryResult we're going to hand to the user - if isAuthCheck { - retv.Authorities = append(retv.Authorities, cachedAnswer.Answer) - } else { - retv.Answers = append(retv.Answers, cachedAnswer.Answer) - } + retv.Answers = append(retv.Answers, cachedAnswer.Answer) } } - s.IterativeCache.Unlock(q) + s.IterativeCache.Unlock(cacheKey) // Don't return an empty response. if len(retv.Answers) == 0 && len(retv.Authorities) == 0 && len(retv.Additional) == 0 { s.VerboseLog(depth+2, "-> no entry found in cache, after expiration") var emptyRetv SingleQueryResult return emptyRetv, false } + if ns != nil { + retv.Resolver = ns.String() + } s.VerboseLog(depth+2, "Cache hit: ", retv) return retv, true } -func (s *Cache) SafeAddCachedAnswer(a interface{}, layer string, debugType string, depth int) { +func (s *Cache) SafeAddCachedAnswer(a interface{}, ns *NameServer, layer string, debugType string, depth int) { ans, ok := a.(Answer) if !ok { s.VerboseLog(depth+1, "unable to cast ", debugType, ": ", layer, ": ", a) @@ -145,19 +159,19 @@ func (s *Cache) SafeAddCachedAnswer(a interface{}, layer string, debugType strin log.Info("detected poison ", debugType, ": ", ans.Name, "(", ans.Type, "): ", layer, ": ", a) return } - s.AddCachedAnswer(a, depth) + s.AddCachedAnswer(a, ns, depth) } -func (s *Cache) CacheUpdate(layer string, result SingleQueryResult, depth int) { +func (s *Cache) CacheUpdate(layer string, result SingleQueryResult, ns *NameServer, depth int, cacheNonAuthoritativeAns bool) { for _, a := range result.Additional { - s.SafeAddCachedAnswer(a, layer, "additional", depth) + s.SafeAddCachedAnswer(a, ns, layer, "additional", depth) } for _, a := range result.Authorities { - s.SafeAddCachedAnswer(a, layer, "authority", depth) + s.SafeAddCachedAnswer(a, ns, layer, "authority", depth) } - if result.Flags.Authoritative { + if result.Flags.Authoritative || cacheNonAuthoritativeAns { for _, a := range result.Answers { - s.SafeAddCachedAnswer(a, layer, "answer", depth) + s.SafeAddCachedAnswer(a, ns, layer, "answer", depth) } } } diff --git a/src/zdns/cache_test.go b/src/zdns/cache_test.go new file mode 100644 index 00000000..8c837ba3 --- /dev/null +++ b/src/zdns/cache_test.go @@ -0,0 +1,148 @@ +/* + * ZDNS Copyright 2024 Regents of the University of Michigan + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package zdns + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckForNonExistentKey(t *testing.T) { + cache := Cache{} + cache.Init(4096) + _, found := cache.GetCachedResult(Question{1, 1, "google.com"}, nil, 0) + assert.False(t, found, "Expected no cache entry") +} + +func TestNoNameServerLookupSuccess(t *testing.T) { + res := SingleQueryResult{ + Answers: []interface{}{Answer{ + TTL: 3600, + RrType: 1, + RrClass: 1, + Name: "google.com.", + Answer: "192.0.2.1", + }}, + Additional: nil, + Authorities: nil, + Protocol: "", + Flags: DNSFlags{Authoritative: true}, + } + cache := Cache{} + cache.Init(4096) + cache.CacheUpdate(".", res, nil, 0, false) + _, found := cache.GetCachedResult(Question{1, 1, "google.com."}, nil, 0) + assert.True(t, found, "Expected cache entry") +} + +func TestNoNameServerLookupForNamedNameServer(t *testing.T) { + res := SingleQueryResult{ + Answers: []interface{}{Answer{ + TTL: 3600, + RrType: 1, + RrClass: 1, + Name: "google.com.", + Answer: "192.0.2.1", + }}, + Additional: nil, + Authorities: nil, + Protocol: "", + Flags: DNSFlags{Authoritative: true}, + } + cache := Cache{} + cache.Init(4096) + cache.CacheUpdate(".", res, nil, 0, false) + _, found := cache.GetCachedResult(Question{1, 1, "google.com."}, &NameServer{ + IP: net.ParseIP("1.1.1.1"), + Port: 53, + }, 0) + assert.False(t, found, "Cache has an answer from a generic nameserver, we wanted a specific one. Shouldn't be found.") +} + +func TestNamedServerLookupForNonNamedNameServer(t *testing.T) { + res := SingleQueryResult{ + Answers: []interface{}{Answer{ + TTL: 3600, + RrType: 1, + RrClass: 1, + Name: "google.com.", + Answer: "192.0.2.1", + }}, + Additional: nil, + Authorities: nil, + Protocol: "", + Flags: DNSFlags{Authoritative: true}, + } + cache := Cache{} + cache.Init(4096) + cache.CacheUpdate(".", res, &NameServer{ + IP: net.ParseIP("1.1.1.1"), + Port: 53, + }, 0, false) + _, found := cache.GetCachedResult(Question{1, 1, "google.com."}, nil, 0) + assert.False(t, found, "Cache has an answer from a named nameserver, we wanted a generic one. Shouldn't be found.") +} + +func TestNamedServerLookupForNamedNameServer(t *testing.T) { + res := SingleQueryResult{ + Answers: []interface{}{Answer{ + TTL: 3600, + RrType: 1, + RrClass: 1, + Name: "google.com.", + Answer: "192.0.2.1", + }}, + Additional: nil, + Authorities: nil, + Protocol: "", + Flags: DNSFlags{Authoritative: true}, + } + cache := Cache{} + cache.Init(4096) + cache.CacheUpdate(".", res, &NameServer{ + IP: net.ParseIP("1.1.1.1"), + Port: 53, + }, 0, false) + _, found := cache.GetCachedResult(Question{1, 1, "google.com."}, &NameServer{ + IP: net.ParseIP("1.1.1.1"), + Port: 53, + }, 0) + assert.True(t, found, "Should be found") +} + +func TestNoNameServerLookupNotAuthoritative(t *testing.T) { + res := SingleQueryResult{ + Answers: []interface{}{Answer{ + TTL: 3600, + RrType: 1, + RrClass: 1, + Name: "google.com.", + Answer: "192.0.2.1", + }}, + Additional: nil, + Authorities: nil, + Protocol: "", + Flags: DNSFlags{Authoritative: false}, + } + cache := Cache{} + cache.Init(4096) + cache.CacheUpdate(".", res, nil, 0, false) + _, found := cache.GetCachedResult(Question{1, 1, "google.com."}, nil, 0) + assert.False(t, found, "shouldn't cache non-authoritative answers") + cache.CacheUpdate(".", res, nil, 0, true) + _, found = cache.GetCachedResult(Question{1, 1, "google.com."}, nil, 0) + assert.True(t, found, "should cache non-authoritative answers") +} diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index aea263d7..383e8ebe 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -136,7 +136,7 @@ func (r *Resolver) lookup(ctx context.Context, q Question, nameServer *NameServe tries := 0 // external lookup r.verboseLog(1, "MIEKG-IN: following external lookup for ", q.Name, " (", q.Type, ")") - res, status, tries, err = r.retryingLookup(ctx, q, nameServer, true) + res, _, status, tries, err = r.cachedRetryingLookup(ctx, q, nameServer, q.Name, 1, true, true, true) r.verboseLog(1, "MIEKG-OUT: following external lookup for ", q.Name, " (", q.Type, ") with ", tries, " attempts: status: ", status, " , err: ", err) var t TraceStep t.Result = res @@ -343,7 +343,7 @@ func (r *Resolver) iterativeLookup(ctx context.Context, q Question, nameServer * // create iteration context for this iteration step iterationStepCtx, cancel := context.WithTimeout(ctx, r.iterativeTimeout) defer cancel() - result, isCached, status, try, err := r.cachedRetryingLookup(iterationStepCtx, q, nameServer, layer, depth) + result, isCached, status, try, err := r.cachedRetryingLookup(iterationStepCtx, q, nameServer, layer, depth, false, false, false) if status == StatusNoError { var t TraceStep t.Result = result @@ -391,15 +391,39 @@ func (r *Resolver) iterativeLookup(ctx context.Context, q Question, nameServer * } } -func (r *Resolver) cachedRetryingLookup(ctx context.Context, q Question, nameServer *NameServer, layer string, depth int) (SingleQueryResult, IsCached, Status, int, error) { +// cachedRetryingLookup wraps around retryingLookup to perform a DNS lookup with caching +// returns the result, whether it was cached, the status, the number of tries, and an error if one occured +// layer is the domain name layer we're currently querying ex: ".", "com.", "example.com." +// depth is the current depth of the lookup, used for iterative lookups +// requestIteration is whether to set the "recursion desired" bit in the DNS query +// cacheBasedOnNameServer is whether to consider a cache hit based on DNS question and nameserver, or just question +// cacheNonAuthoritative is whether to cache non-authoritative answers, usually used for lookups using an external resolver +func (r *Resolver) cachedRetryingLookup(ctx context.Context, q Question, nameServer *NameServer, layer string, depth int, requestIteration, cacheBasedOnNameServer, cacheNonAuthoritative bool) (SingleQueryResult, IsCached, Status, int, error) { var isCached IsCached isCached = false r.verboseLog(depth+1, "Cached retrying lookup. Name: ", q, ", Layer: ", layer, ", Nameserver: ", nameServer) + // For some lookups, we want them to be nameserver specific, ie. if cacheBasedOnNameServer is true + // Else, we don't care which nameserver returned it + cacheNameServer := nameServer + if !cacheBasedOnNameServer { + cacheNameServer = nil + } // First, we check the answer - cachedResult, ok := r.cache.GetCachedResult(q, false, depth+1) + cachedResult, ok := r.cache.GetCachedResult(q, cacheNameServer, depth+1) if ok { isCached = true + // set protocol on the result + if r.dnsOverHTTPSEnabled { + cachedResult.Protocol = DoHProtocol + } else if r.dnsOverTLSEnabled { + cachedResult.Protocol = DoTProtocol + } else if r.transportMode == TCPOnly { + cachedResult.Protocol = TCPProtocol + } else { + // default to UDP + cachedResult.Protocol = UDPProtocol + } return cachedResult, isCached, StatusNoError, 0, nil } @@ -415,9 +439,9 @@ func (r *Resolver) cachedRetryingLookup(ctx context.Context, q Question, nameSer } // Alright, we're not sure what to do, go to the wire. - result, status, try, err := r.retryingLookup(ctx, q, nameServer, false) + result, status, try, err := r.retryingLookup(ctx, q, nameServer, requestIteration) - r.cache.CacheUpdate(layer, result, depth+2) + r.cache.CacheUpdate(layer, result, cacheNameServer, depth+2, cacheNonAuthoritative) return result, isCached, status, try, err } @@ -535,7 +559,7 @@ func doDoTLookup(ctx context.Context, connInfo *ConnectionInfo, q Question, name } res := SingleQueryResult{ Resolver: connInfo.tlsConn.Conn.RemoteAddr().String(), - Protocol: "DoT", + Protocol: DoTProtocol, Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}, @@ -607,7 +631,7 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS } res := SingleQueryResult{ Resolver: nameServer.DomainName, - Protocol: "DoH", + Protocol: DoHProtocol, Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}, diff --git a/src/zdns/types.go b/src/zdns/types.go index 309df459..bccd5525 100644 --- a/src/zdns/types.go +++ b/src/zdns/types.go @@ -20,6 +20,13 @@ import ( "github.com/zmap/zdns/src/internal/util" ) +const ( + DoHProtocol = "DoH" + DoTProtocol = "DoT" + UDPProtocol = "udp" + TCPProtocol = "tcp" +) + type transportMode int const ( diff --git a/testing/integration_tests.py b/testing/integration_tests.py index 73a88757..db7b36ac 100755 --- a/testing/integration_tests.py +++ b/testing/integration_tests.py @@ -1415,6 +1415,18 @@ def test_dnssec_option(self): break self.assertTrue(hasRRSIG, "DNSSEC option should return an RRSIG record") + def test_external_lookup_cache(self): + c = "A google.com google.com --name-servers=8.8.8.8 --threads=1" + name = "" + cmd, res = self.run_zdns_multiline_output(c, name, append_flags=False) + self.assertSuccess(res[0], cmd, "A") + self.assertSuccess(res[1], cmd, "A") + first_duration = res[0]["results"]["A"]["duration"] + second_duration = res[1]["results"]["A"]["duration"] + # a bit of a hacky test, but we're checking that if we query the same domain with the same nameserver, + # the second query has a much smaller response time than the first to show it's being cached + self.assertTrue(first_duration / 50 > second_duration, f"Second query {second_duration} should be faster than the first {first_duration}") +