Skip to content

Commit

Permalink
Merge pull request #190 from mailgun/maxim/develop
Browse files Browse the repository at this point in the history
PIP-2872: Fix data race in mxresolv
  • Loading branch information
horkhe authored Apr 19, 2024
2 parents a2fabf4 + 60e22fd commit 20e430d
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 78 deletions.
136 changes: 81 additions & 55 deletions mxresolv/mxresolv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"sort"
"strings"
"sync"
"time"
"unicode"
_ "unsafe" // For go:linkname
Expand All @@ -26,8 +27,9 @@ var (
errNoValidMXHosts = errors.New("no valid MX hosts")
lookupResultCache *collections.LRUCache

// defaultSeed allows the seed function to be patched in tests using SetDeterministic()
defaultRand = newRand
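// randomizer is the shared random source used to shuffle MX hosts; access to
// it is guarded by randomizerMu. Tests can swap it for a deterministic one
// using SetDeterministicInTests().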
randomizerMu sync.Mutex
randomizer = rand.New(rand.NewSource(time.Now().UnixNano()))

// Resolver is exposed to be patched in tests
Resolver = net.DefaultResolver
Expand All @@ -37,23 +39,20 @@ func init() {
lookupResultCache = collections.NewLRUCache(cacheSize)
}

// newRand returns a new *rand.Rand seeded with the current Unix time in
// nanoseconds.
func newRand() *rand.Rand {
return rand.New(rand.NewSource(time.Now().UnixNano()))
}

// Lookup performs a DNS lookup of MX records for the specified hostname. It
// returns a prioritised list of MX hostnames, where hostnames with the same
// priority are shuffled. If the second returned value is true, then the host
// does not have explicit MX records, and its A record is returned instead.
//
// It uses an LRU cache with a timeout to reduce the number of network requests.
func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImplicit bool, reterr error) {
if obj, ok := lookupResultCache.Get(hostname); ok {
cached := obj.(lookupResult)
if len(cached.mxRecords) != 0 {
return shuffleMXRecords(cached.mxRecords), cached.implicit, cached.err
if cachedVal, ok := lookupResultCache.Get(hostname); ok {
cachedLookupResult := cachedVal.(lookupResult)
if cachedLookupResult.shuffled {
reshuffledMXHosts, _ := shuffleMXRecords(cachedLookupResult.mxRecords)
return reshuffledMXHosts, cachedLookupResult.implicit, cachedLookupResult.err
}
return cached.mxHosts, cached.implicit, cached.err
return cachedLookupResult.mxHosts, cachedLookupResult.implicit, cachedLookupResult.err
}

asciiHostname, err := ensureASCII(hostname)
Expand All @@ -67,25 +66,25 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli
return nil, false, errors.WithStack(err)
}
var netDNSError *net.DNSError
if errors.As(err, &netDNSError) && netDNSError.Err == "no such host" {
if errors.As(err, &netDNSError) && netDNSError.IsNotFound {
if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil {
return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err))
return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err))
}
return cacheAndReturn(hostname, []string{asciiHostname}, nil, true, nil)
return cacheAndReturn(hostname, []string{asciiHostname}, nil, false, true, nil)
}
if mxRecords == nil {
return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err))
return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err))
}
}
// Check for "Null MX" record (https://tools.ietf.org/html/rfc7505).
if len(mxRecords) == 1 {
if mxRecords[0].Host == "." {
return cacheAndReturn(hostname, nil, nil, false, errNullMXRecord)
return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord)
}
// 0.0.0.0 is not really a "Null MX" record, but some people apparently
// have never heard of RFC7505 and configure it this way.
if strings.HasPrefix(mxRecords[0].Host, "0.0.0.0") {
return cacheAndReturn(hostname, nil, nil, false, errNullMXRecord)
return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord)
}
}
// Normalize returned hostnames: drop trailing '.' and lowercase.
Expand All @@ -101,19 +100,24 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli
return mxRecords[i].Pref < mxRecords[j].Pref ||
(mxRecords[i].Pref == mxRecords[j].Pref && mxRecords[i].Host < mxRecords[j].Host)
})
mxHosts := shuffleMXRecords(mxRecords)
mxHosts, shuffled := shuffleMXRecords(mxRecords)
if len(mxHosts) == 0 {
return cacheAndReturn(hostname, nil, nil, false, errNoValidMXHosts)
return cacheAndReturn(hostname, nil, nil, false, false, errNoValidMXHosts)
}
return cacheAndReturn(hostname, mxHosts, mxRecords, false, nil)
return cacheAndReturn(hostname, mxHosts, mxRecords, shuffled, false, nil)
}

// SetDeterministic sets rand to deterministic seed for testing, and is not Thread-Safe
func SetDeterministic() func() {
r := rand.New(rand.NewSource(1))
defaultRand = func() *rand.Rand { return r }
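// SetDeterministicInTests replaces the shared randomizer with one seeded with
// 1 and returns a function that restores the previous randomizer. The swap
// itself is guarded by randomizerMu, but overlapping calls (e.g. from parallel
// tests) may restore randomizers out of order, so do not use it concurrently.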
func SetDeterministicInTests() func() {
randomizerMu.Lock()
old := randomizer
randomizer = rand.New(rand.NewSource(1))
randomizerMu.Unlock()
return func() {
defaultRand = newRand
randomizerMu.Lock()
randomizer = old
randomizerMu.Unlock()
}
}

Expand All @@ -122,43 +126,64 @@ func ResetCache() {
lookupResultCache = collections.NewLRUCache(1000)
}

func shuffleMXRecords(mxRecords []*net.MX) []string {
r := defaultRand()

// Shuffle the hosts within the preference groups
begin := 0
for i := 0; i <= len(mxRecords); i++ {
// If we are on the last record shuffle the last preference group
if i == len(mxRecords) {
group := mxRecords[begin:i]
r.Shuffle(len(group), func(i, j int) {
group[i], group[j] = group[j], group[i]
})
break
}

// After finding the end of a preference group, shuffle it
if mxRecords[begin].Pref != mxRecords[i].Pref {
group := mxRecords[begin:i]
r.Shuffle(len(group), func(i, j int) {
group[i], group[j] = group[j], group[i]
})
begin = i
}
}

// Make a hostname list, but skip non-ASCII names, that cause issues.
mxHosts := make([]string, 0, len(mxRecords))
func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) {
// Shuffle the hosts within the preference groups.
var (
mxHosts []string
groupBegin = 0
groupEnd = 0
groupPref uint16
shuffled = false
)
for _, mxRecord := range mxRecords {
// If a hostname has non-ASCII characters then ignore it, for it is
// a kind of human error that we saw in production.
if !isASCII(mxRecord.Host) {
continue
}
// Just being overly cautious, so checking for empty values.
if mxRecord.Host == "" {
continue
}
// If it is the first valid record in the set, then allocate a slice
// for MX hosts and put it there.
if mxHosts == nil {
mxHosts = make([]string, 0, len(mxRecords))
mxHosts = append(mxHosts, mxRecord.Host)
groupPref = mxRecord.Pref
groupEnd = 1
continue
}
// Put the next valid record to the slice.
mxHosts = append(mxHosts, mxRecord.Host)
// If the added host has the same preference as the first one in the
// current group, then continue the MX record set traversal.
if groupPref == mxRecord.Pref {
groupEnd++
continue
}
// After finding the end of the current preference group, shuffle it.
if groupEnd-groupBegin > 1 {
shuffleHosts(mxHosts[groupBegin:groupEnd])
shuffled = true
}
// Set up the next preference group.
groupBegin = groupEnd
groupEnd++
groupPref = mxRecord.Pref
}
return mxHosts
// Shuffle the last preference group, if there is one.
if groupEnd-groupBegin > 1 {
shuffleHosts(mxHosts[groupBegin:groupEnd])
shuffled = true
}
return mxHosts, shuffled
}

// shuffleHosts shuffles hosts in place using the package-level randomizer.
//
// randomizerMu is held for the duration of the shuffle because randomizer is
// shared by all callers and may be swapped by SetDeterministicInTests.
func shuffleHosts(hosts []string) {
	randomizerMu.Lock()
	// Release via defer so the mutex is unlocked on every exit path, including
	// a panic inside Shuffle.
	defer randomizerMu.Unlock()
	randomizer.Shuffle(len(hosts), func(i, j int) {
		hosts[i], hosts[j] = hosts[j], hosts[i]
	})
}

func ensureASCII(hostname string) (string, error) {
Expand All @@ -184,12 +209,13 @@ func isASCII(s string) bool {
// lookupResult is the value stored in lookupResultCache for a hostname by
// cacheAndReturn and read back by Lookup on a cache hit.
type lookupResult struct {
// mxRecords are the MX records passed to cacheAndReturn. On a cache hit with
// shuffled set, Lookup passes them to shuffleMXRecords to get a fresh order.
mxRecords []*net.MX
// mxHosts is the host list returned as-is on a cache hit when shuffled is
// false.
mxHosts []string
// shuffled reports whether shuffleMXRecords shuffled at least one preference
// group when this result was produced.
shuffled bool
// implicit is returned as Lookup's second result; per Lookup's doc comment
// it is true when the host has no explicit MX records and its A record is
// returned instead.
implicit bool
// err is the error cached alongside the result and returned to the caller.
err error
}

func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) {
lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, implicit: implicit, err: err}, cacheTTL)
func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, shuffled, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) {
lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, shuffled: shuffled, implicit: implicit, err: err}, cacheTTL)
return mxHosts, implicit, err
}

Expand Down
53 changes: 30 additions & 23 deletions mxresolv/mxresolv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/rand"
"net"
"os"
"reflect"
"regexp"
"sort"
"testing"
Expand Down Expand Up @@ -166,7 +167,7 @@ func TestLookup(t *testing.T) {
outImplicitMX: false,
}} {
t.Run(tc.inDomainName, func(t *testing.T) {
defer mxresolv.SetDeterministic()()
defer mxresolv.SetDeterministicInTests()()

// When
ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second)
Expand All @@ -181,7 +182,7 @@ func TestLookup(t *testing.T) {
}

func TestLookupRegression(t *testing.T) {
defer mxresolv.SetDeterministic()()
defer mxresolv.SetDeterministicInTests()()
mxresolv.ResetCache()

// When
Expand All @@ -190,51 +191,57 @@ func TestLookupRegression(t *testing.T) {

mxHosts, explictMX, err := mxresolv.Lookup(ctx, "test-mx.definbox.com")
// Then
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []string{
"mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com", "mxc.definbox.com",
"mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com",
"mxg.definbox.com",
/* 1 */ "mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com",
/* 2 */ "mxc.definbox.com",
/* 3 */ "mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxg.definbox.com",
}, mxHosts)
assert.Equal(t, false, explictMX)

// The second lookup is served from the cache, and the cached result is reshuffled.
mxHosts, explictMX, err = mxresolv.Lookup(ctx, "test-mx.definbox.com")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []string{
"mxi.definbox.com", "mxe.definbox.com", "mxa.definbox.com", "mxc.definbox.com",
"mxg.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxf.definbox.com",
"mxb.definbox.com",
/* 1 */ "mxe.definbox.com", "mxi.definbox.com", "mxa.definbox.com",
/* 2 */ "mxc.definbox.com",
/* 3 */ "mxh.definbox.com", "mxf.definbox.com", "mxg.definbox.com", "mxd.definbox.com", "mxb.definbox.com",
}, mxHosts)
assert.Equal(t, false, explictMX)

mxHosts, _, err = mxresolv.Lookup(ctx, "definbox.com")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []string{"mxb.ninomail.com", "mxa.ninomail.com"}, mxHosts)

// Should always prefer mxb over mxa since mxb has a lower pref than mxa
for i := 0; i < 100; i++ {
mxHosts, _, err = mxresolv.Lookup(ctx, "prefer.example.com")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []string{"mxb.example.com", "mxa.example.com"}, mxHosts)
}

// Should randomly order mxa and mxb while mxc should always be last
// Should randomly order mxa and mxb. We perform the lookup 10 times and make
// sure that the returned result is not always the same.
mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []string{"mxb.example.com", "mxa.example.com", "mxc.example.com"}, mxHosts)
sameCount := 0
for i := 0; i < 10; i++ {
mxHosts2, _, err := mxresolv.Lookup(ctx, "prefer3.example.com")
assert.NoError(t, err)
if reflect.DeepEqual(mxHosts, mxHosts2) {
sameCount++
}
}
assert.Less(t, sameCount, 10)

mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com")
assert.NoError(t, err)
assert.Equal(t, []string{"mxa.example.com", "mxb.example.com", "mxc.example.com"}, mxHosts)

// 'mxc.example.com' should always be last as it has a different priority than the other two.
// mxc.example.com should always be last as it has a different priority than
// the other two.
for i := 0; i < 100; i++ {
mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "mxc.example.com", mxHosts[2])
}

}

func TestLookupError(t *testing.T) {
Expand Down Expand Up @@ -291,15 +298,15 @@ func TestLookupError(t *testing.T) {
}
}

// Shuffling only does not cross preference group boundaries.
// Shuffling does not cross preference group boundaries.
//
// Preference groups are:
//
// 1: mxa.definbox.com, mxe.definbox.com, mxi.definbox.com
// 2: mxc.definbox.com
// 3: mxb.definbox.com, mxd.definbox.com, mxf.definbox.com, mxg.definbox.com, mxh.definbox.com
func TestLookupShuffle(t *testing.T) {
defer mxresolv.SetDeterministic()()
defer mxresolv.SetDeterministicInTests()()

// When
ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second)
Expand Down

0 comments on commit 20e430d

Please sign in to comment.