diff --git a/mxresolv/mxresolv.go b/mxresolv/mxresolv.go index cdb3256..77bc3fa 100644 --- a/mxresolv/mxresolv.go +++ b/mxresolv/mxresolv.go @@ -6,6 +6,7 @@ import ( "net" "sort" "strings" + "sync" "time" "unicode" _ "unsafe" // For go:linkname @@ -26,8 +27,9 @@ var ( errNoValidMXHosts = errors.New("no valid MX hosts") lookupResultCache *collections.LRUCache - // defaultSeed allows the seed function to be patched in tests using SetDeterministic() - defaultRand = newRand + // randomizer allows the seed function to be patched in tests using SetDeterministic() + randomizerMu sync.Mutex + randomizer = rand.New(rand.NewSource(time.Now().UnixNano())) // Resolver is exposed to be patched in tests Resolver = net.DefaultResolver @@ -37,10 +39,6 @@ func init() { lookupResultCache = collections.NewLRUCache(cacheSize) } -func newRand() *rand.Rand { - return rand.New(rand.NewSource(time.Now().UnixNano())) -} - // Lookup performs a DNS lookup of MX records for the specified hostname. It // returns a prioritised list of MX hostnames, where hostnames with the same // priority are shuffled. If the second returned value is true, then the host @@ -48,12 +46,13 @@ func newRand() *rand.Rand { // // It uses an LRU cache with a timeout to reduce the number of network requests. func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImplicit bool, reterr error) { - if obj, ok := lookupResultCache.Get(hostname); ok { - cached := obj.(lookupResult) - if len(cached.mxRecords) != 0 { - return shuffleMXRecords(cached.mxRecords), cached.implicit, cached.err + if cachedVal, ok := lookupResultCache.Get(hostname); ok { + cachedLookupResult := cachedVal.(lookupResult) + if cachedLookupResult.shuffled { + reshuffledMXHosts, _ := shuffleMXRecords(cachedLookupResult.mxRecords) + return reshuffledMXHosts, cachedLookupResult.implicit, cachedLookupResult.err } - return cached.mxHosts, cached.implicit, cached.err + return cachedLookupResult.mxHosts, cachedLookupResult.implicit, cachedLookupResult.err } asciiHostname, err := ensureASCII(hostname) @@ -67,25 +66,25 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli return nil, false, errors.WithStack(err) } var netDNSError *net.DNSError - if errors.As(err, &netDNSError) && netDNSError.Err == "no such host" { + if errors.As(err, &netDNSError) && netDNSError.IsNotFound { if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil { - return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err)) + return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err)) } - return cacheAndReturn(hostname, []string{asciiHostname}, nil, true, nil) + return cacheAndReturn(hostname, []string{asciiHostname}, nil, false, true, nil) } if mxRecords == nil { - return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err)) + return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err)) } } // Check for "Null MX" record (https://tools.ietf.org/html/rfc7505). if len(mxRecords) == 1 { if mxRecords[0].Host == "." { - return cacheAndReturn(hostname, nil, nil, false, errNullMXRecord) + return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord) } // 0.0.0.0 is not really a "Null MX" record, but some people apparently // have never heard of RFC7505 and configure it this way. if strings.HasPrefix(mxRecords[0].Host, "0.0.0.0") { - return cacheAndReturn(hostname, nil, nil, false, errNullMXRecord) + return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord) } } // Normalize returned hostnames: drop trailing '.' and lowercase. @@ -101,19 +100,24 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli return mxRecords[i].Pref < mxRecords[j].Pref || (mxRecords[i].Pref == mxRecords[j].Pref && mxRecords[i].Host < mxRecords[j].Host) }) - mxHosts := shuffleMXRecords(mxRecords) + mxHosts, shuffled := shuffleMXRecords(mxRecords) if len(mxHosts) == 0 { - return cacheAndReturn(hostname, nil, nil, false, errNoValidMXHosts) + return cacheAndReturn(hostname, nil, nil, false, false, errNoValidMXHosts) } - return cacheAndReturn(hostname, mxHosts, mxRecords, false, nil) + return cacheAndReturn(hostname, mxHosts, mxRecords, shuffled, false, nil) } -// SetDeterministic sets rand to deterministic seed for testing, and is not Thread-Safe -func SetDeterministic() func() { - r := rand.New(rand.NewSource(1)) - defaultRand = func() *rand.Rand { return r } +// SetDeterministicInTests sets rand to deterministic seed for testing, and is +// not Thread-Safe. +func SetDeterministicInTests() func() { + randomizerMu.Lock() + old := randomizer + randomizer = rand.New(rand.NewSource(1)) + randomizerMu.Unlock() return func() { - defaultRand = newRand + randomizerMu.Lock() + randomizer = old + randomizerMu.Unlock() } } @@ -122,43 +126,64 @@ func ResetCache() { lookupResultCache = collections.NewLRUCache(1000) } -func shuffleMXRecords(mxRecords []*net.MX) []string { - r := defaultRand() - - // Shuffle the hosts within the preference groups - begin := 0 - for i := 0; i <= len(mxRecords); i++ { - // If we are on the last record shuffle the last preference group - if i == len(mxRecords) { - group := mxRecords[begin:i] - r.Shuffle(len(group), func(i, j int) { - group[i], group[j] = group[j], group[i] - }) - break - } - - // After finding the end of a preference group, shuffle it - if mxRecords[begin].Pref != mxRecords[i].Pref { - group := mxRecords[begin:i] - r.Shuffle(len(group), func(i, j int) { - group[i], group[j] = group[j], group[i] - }) - begin = i - } - } - - // Make a hostname list, but skip non-ASCII names, that cause issues. - mxHosts := make([]string, 0, len(mxRecords)) +func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) { + // Shuffle the hosts within the preference groups. + var ( + mxHosts []string + groupBegin = 0 + groupEnd = 0 + groupPref uint16 + shuffled = false + ) for _, mxRecord := range mxRecords { + // If a hostname has non-ASCII characters then ignore it, for it is + // a kind of human error that we saw in production. if !isASCII(mxRecord.Host) { continue } + // Just being overly cautious, so checking for empty values. if mxRecord.Host == "" { continue } + // If it is the first valid record in the set, then allocate a slice + // for MX hosts and put it there. + if mxHosts == nil { + mxHosts = make([]string, 0, len(mxRecords)) + mxHosts = append(mxHosts, mxRecord.Host) + groupPref = mxRecord.Pref + groupEnd = 1 + continue + } + // Put the next valid record to the slice. mxHosts = append(mxHosts, mxRecord.Host) + // If the added host has the same preference as the first one in the + // current group, then continue the MX record set traversal. + if groupPref == mxRecord.Pref { + groupEnd++ + continue + } + // After finding the end of the current preference group, shuffle it. + if groupEnd-groupBegin > 1 { + shuffleHosts(mxHosts[groupBegin:groupEnd]) + shuffled = true + } + // Set up the next preference group. + groupBegin = groupEnd + groupEnd++ + groupPref = mxRecord.Pref } - return mxHosts + // Shuffle the last preference group, if there is one. + if groupEnd-groupBegin > 1 { + shuffleHosts(mxHosts[groupBegin:groupEnd]) + shuffled = true + } + return mxHosts, shuffled +} + +func shuffleHosts(hosts []string) { + randomizerMu.Lock() + randomizer.Shuffle(len(hosts), func(i, j int) { hosts[i], hosts[j] = hosts[j], hosts[i] }) + randomizerMu.Unlock() } func ensureASCII(hostname string) (string, error) { @@ -184,12 +209,13 @@ func isASCII(s string) bool { type lookupResult struct { mxRecords []*net.MX mxHosts []string + shuffled bool implicit bool err error } -func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) { - lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, implicit: implicit, err: err}, cacheTTL) +func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, shuffled, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) { + lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, shuffled: shuffled, implicit: implicit, err: err}, cacheTTL) return mxHosts, implicit, err } diff --git a/mxresolv/mxresolv_test.go b/mxresolv/mxresolv_test.go index c12a654..8e95de3 100644 --- a/mxresolv/mxresolv_test.go +++ b/mxresolv/mxresolv_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "os" + "reflect" "regexp" "sort" "testing" @@ -166,7 +167,7 @@ func TestLookup(t *testing.T) { outImplicitMX: false, }} { t.Run(tc.inDomainName, func(t *testing.T) { - defer mxresolv.SetDeterministic()() + defer mxresolv.SetDeterministicInTests()() // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) @@ -181,7 +182,7 @@ func TestLookup(t *testing.T) { } func TestLookupRegression(t *testing.T) { - defer mxresolv.SetDeterministic()() + defer mxresolv.SetDeterministicInTests()() mxresolv.ResetCache() // When @@ -190,51 +191,57 @@ func TestLookupRegression(t *testing.T) { mxHosts, explictMX, err := mxresolv.Lookup(ctx, "test-mx.definbox.com") // Then - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []string{ - "mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com", "mxc.definbox.com", - "mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com", - "mxg.definbox.com", + /* 1 */ "mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com", + /* 2 */ "mxc.definbox.com", + /* 3 */ "mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxg.definbox.com", }, mxHosts) assert.Equal(t, false, explictMX) // The second lookup returns the cached result, the cached result is shuffled. mxHosts, explictMX, err = mxresolv.Lookup(ctx, "test-mx.definbox.com") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []string{ - "mxi.definbox.com", "mxe.definbox.com", "mxa.definbox.com", "mxc.definbox.com", - "mxg.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxf.definbox.com", - "mxb.definbox.com", + /* 1 */ "mxe.definbox.com", "mxi.definbox.com", "mxa.definbox.com", + /* 2 */ "mxc.definbox.com", + /* 3 */ "mxh.definbox.com", "mxf.definbox.com", "mxg.definbox.com", "mxd.definbox.com", "mxb.definbox.com", }, mxHosts) assert.Equal(t, false, explictMX) mxHosts, _, err = mxresolv.Lookup(ctx, "definbox.com") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []string{"mxb.ninomail.com", "mxa.ninomail.com"}, mxHosts) // Should always prefer mxb over mxa since mxb has a lower pref than mxa for i := 0; i < 100; i++ { mxHosts, _, err = mxresolv.Lookup(ctx, "prefer.example.com") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []string{"mxb.example.com", "mxa.example.com"}, mxHosts) } - // Should randomly order mxa and mxb while mxc should always be last + // Should randomly order mxa and mxb. We make lookup 10 times and make sure + // that the returned result is not always the same. mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []string{"mxb.example.com", "mxa.example.com", "mxc.example.com"}, mxHosts) + sameCount := 0 + for i := 0; i < 10; i++ { + mxHosts2, _, err := mxresolv.Lookup(ctx, "prefer3.example.com") + assert.NoError(t, err) + if reflect.DeepEqual(mxHosts, mxHosts2) { + sameCount++ + } + } + assert.Less(t, sameCount, 10) - mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") - assert.NoError(t, err) - assert.Equal(t, []string{"mxa.example.com", "mxb.example.com", "mxc.example.com"}, mxHosts) - - // 'mxc.example.com' should always be last as it has a different priority than the other two. + // mxc.example.com should always be last as it has a different priority, + // than the other two. for i := 0; i < 100; i++ { mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "mxc.example.com", mxHosts[2]) } - } func TestLookupError(t *testing.T) { @@ -291,7 +298,7 @@ func TestLookupError(t *testing.T) { } } -// Shuffling only does not cross preference group boundaries. +// Shuffling does not cross preference group boundaries. // // Preference groups are: // @@ -299,7 +306,7 @@ func TestLookupError(t *testing.T) { // 2: mxc.definbox.com // 3: mxb.definbox.com, mxd.definbox.com, mxf.definbox.com, mxg.definbox.com, mxh.definbox.com func TestLookupShuffle(t *testing.T) { - defer mxresolv.SetDeterministic()() + defer mxresolv.SetDeterministicInTests()() // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second)