Skip to content

Commit

Permalink
Merge pull request #175 from mailgun/thrawn/mxresolv-rand
Browse files Browse the repository at this point in the history
mxresolv.Lookup() now properly randomized hosts in some situations
  • Loading branch information
thrawn01 authored Jun 21, 2023
2 parents ccaeadc + 43e77ef commit fa93174
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 104 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ issues:
- noctx
- path: '_test\.go$'
text: "unnamedResult:"
- path: '.*mxresolv.*'
linters:
- gosec


run:
Expand Down
43 changes: 43 additions & 0 deletions mxresolv/dns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package mxresolv_test

import (
"net"
"sync"

"github.com/foxcpp/go-mockdns"
)

type MockDNS struct {
Server *mockdns.Server
mu sync.Mutex
}

func SpawnMockDNS(zones map[string]mockdns.Zone) (*MockDNS, error) {
server, err := mockdns.NewServerWithLogger(zones, nullLogger{}, false)
if err != nil {
return nil, err
}
return &MockDNS{
Server: server,
}, nil
}

func (f *MockDNS) Stop() {
_ = f.Server.Close()
}

func (f *MockDNS) Patch(r *net.Resolver) {
f.mu.Lock()
defer f.mu.Unlock()
f.Server.PatchNet(r)
}

func (f *MockDNS) UnPatch(r *net.Resolver) {
f.mu.Lock()
defer f.mu.Unlock()
mockdns.UnpatchNet(r)
}

type nullLogger struct{}

func (l nullLogger) Printf(_ string, _ ...interface{}) {}
67 changes: 47 additions & 20 deletions mxresolv/mxresolv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"sort"
"strings"
"time"
"unicode"
_ "unsafe" // For go:linkname

Expand All @@ -23,21 +24,23 @@ const (
var (
errNullMXRecord = errors.New("domain accepts no mail")
errNoValidMXHosts = errors.New("no valid MX hosts")

lookupResultCache *collections.LRUCache

// It is modified only in tests to make them deterministic.
shuffle = true
// defaultSeed allows the seed function to be patched in tests using SetDeterministic()
defaultRand = newRand

// DefaultResolver is exposed mainly to be patched in tests to access a
// mock DNS server github.com/foxcpp/go-mockdns.
DefaultResolver = net.DefaultResolver
// Resolver is exposed to be patched in tests
Resolver = net.DefaultResolver
)

func init() {
lookupResultCache = collections.NewLRUCache(cacheSize)
}

func newRand() *rand.Rand {
return rand.New(rand.NewSource(time.Now().UnixNano()))
}

// Lookup performs a DNS lookup of MX records for the specified hostname. It
// returns a prioritised list of MX hostnames, where hostnames with the same
// priority are shuffled. If the second returned value is true, then the host
Expand All @@ -57,15 +60,15 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli
if err != nil {
return nil, false, errors.Wrap(err, "invalid hostname")
}
mxRecords, err := lookupMX(DefaultResolver, ctx, asciiHostname)
mxRecords, err := lookupMX(Resolver, ctx, asciiHostname)
if err != nil {
var timeouter interface{ Timeout() bool }
if errors.As(err, &timeouter) && timeouter.Timeout() {
return nil, false, errors.WithStack(err)
}
var netDNSError *net.DNSError
if errors.As(err, &netDNSError) && netDNSError.Err == "no such host" {
if _, err := DefaultResolver.LookupIPAddr(ctx, asciiHostname); err != nil {
if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil {
return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err))
}
return cacheAndReturn(hostname, []string{asciiHostname}, nil, true, nil)
Expand Down Expand Up @@ -105,21 +108,45 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli
return cacheAndReturn(hostname, mxHosts, mxRecords, false, nil)
}

// SetDeterministic sets rand to deterministic seed for testing, and is not Thread-Safe
func SetDeterministic() func() {
r := rand.New(rand.NewSource(1))
defaultRand = func() *rand.Rand { return r }
return func() {
defaultRand = newRand
}
}

// ResetCache clears the cache for use in tests, and is not Thread-Safe
func ResetCache() {
lookupResultCache = collections.NewLRUCache(1000)
}

func shuffleMXRecords(mxRecords []*net.MX) []string {
// Shuffle records within preference groups unless disabled in tests.
if shuffle {
mxRecordCount := len(mxRecords) - 1
groupBegin := 0
for i := 1; i <= mxRecordCount; i++ {
if mxRecords[i].Pref != mxRecords[groupBegin].Pref || i == mxRecordCount {
groupSlice := mxRecords[groupBegin:i]
rand.Shuffle(len(groupSlice), func(i, j int) {
groupSlice[i], groupSlice[j] = groupSlice[j], groupSlice[i]
})
groupBegin = i
}
r := defaultRand()

// Shuffle the hosts within the preference groups
begin := 0
for i := 0; i <= len(mxRecords); i++ {
// If we are on the last record shuffle the last preference group
if i == len(mxRecords) {
group := mxRecords[begin:i]
r.Shuffle(len(group), func(i, j int) {
group[i], group[j] = group[j], group[i]
})
break
}

// After finding the end of a preference group, shuffle it
if mxRecords[begin].Pref != mxRecords[i].Pref {
group := mxRecords[begin:i]
r.Shuffle(len(group), func(i, j int) {
group[i], group[j] = group[j], group[i]
})
begin = i
}
}

// Make a hostname list, but skip non-ASCII names, that cause issues.
mxHosts := make([]string, 0, len(mxRecords))
for _, mxRecord := range mxRecords {
Expand Down
Loading

0 comments on commit fa93174

Please sign in to comment.