From 52f5b91df416d35a08a39bda9adf89841de82cb0 Mon Sep 17 00:00:00 2001 From: censhin Date: Thu, 3 Mar 2022 13:50:59 -0500 Subject: [PATCH] refactor(mxresolv/mxresolv.go): adding DefaultResolver to the package --- mxresolv/mxresolv.go | 14 +++++--------- mxresolv/mxresolv_test.go | 12 ++++++------ 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/mxresolv/mxresolv.go b/mxresolv/mxresolv.go index 931548eb..a604695b 100644 --- a/mxresolv/mxresolv.go +++ b/mxresolv/mxresolv.go @@ -27,6 +27,8 @@ var ( // It is modified only in tests to make them deterministic. shuffle = true + + DefaultResolver = net.DefaultResolver ) func init() { @@ -39,13 +41,7 @@ func init() { // does not have explicit MX records, and its A record is returned instead. // // It uses an LRU cache with a timeout to reduce the number of network requests. -func Lookup(ctx context.Context, hostname string, r *net.Resolver) ([]string, bool, error) { - var resolver *net.Resolver - if r == nil { - resolver = net.DefaultResolver - } else { - resolver = r - } +func Lookup(ctx context.Context, hostname string) ([]string, bool, error) { if cachedVal, ok := lookupResultCache.Get(hostname); ok { lookupResult := cachedVal.(lookupResult) return lookupResult.mxHosts, lookupResult.implicit, lookupResult.err @@ -54,7 +50,7 @@ func Lookup(ctx context.Context, hostname string, r *net.Resolver) ([]string, bo if err != nil { return nil, false, errors.Wrap(err, "invalid hostname") } - mxRecords, err := resolver.LookupMX(ctx, asciiHostname) + mxRecords, err := DefaultResolver.LookupMX(ctx, asciiHostname) if err != nil { var timeouter interface{ Timeout() bool } if errors.As(err, &timeouter) && timeouter.Timeout() { @@ -62,7 +58,7 @@ func Lookup(ctx context.Context, hostname string, r *net.Resolver) ([]string, bo } var netDNSError *net.DNSError if errors.As(err, &netDNSError) && netDNSError.Err == "no such host" { - if _, err := resolver.LookupIPAddr(ctx, asciiHostname); err != nil { + if _, err := DefaultResolver.LookupIPAddr(ctx, asciiHostname); err != nil { return cacheAndReturn(hostname, nil, false, errors.WithStack(err)) } return cacheAndReturn(hostname, []string{asciiHostname}, true, nil) diff --git a/mxresolv/mxresolv_test.go b/mxresolv/mxresolv_test.go index b96229a4..2685a502 100644 --- a/mxresolv/mxresolv_test.go +++ b/mxresolv/mxresolv_test.go @@ -58,7 +58,7 @@ func TestLookup(t *testing.T) { fmt.Printf("Test case #%d: %s, %s\n", i, tc.inDomainName, tc.desc) // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) - mxHosts, explictMX, err := Lookup(ctx, tc.inDomainName, nil) + mxHosts, explictMX, err := Lookup(ctx, tc.inDomainName) cancel() // Then assert.NoError(t, err) @@ -67,7 +67,7 @@ func TestLookup(t *testing.T) { // The second lookup returns the cached result, that only shows on the // coverage report. - mxHosts, explictMX, err = Lookup(ctx, tc.inDomainName, nil) + mxHosts, explictMX, err = Lookup(ctx, tc.inDomainName) assert.NoError(t, err) assert.Equal(t, tc.outMXHosts, mxHosts) assert.Equal(t, tc.outImplicitMX, explictMX) @@ -96,14 +96,14 @@ func TestLookupError(t *testing.T) { fmt.Printf("Test case #%d: %s, %s\n", i, tc.inHostname, tc.desc) // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) - _, _, err := Lookup(ctx, tc.inHostname, nil) + _, _, err := Lookup(ctx, tc.inHostname) cancel() // Then assert.Regexp(t, regexp.MustCompile(tc.outError), err.Error()) // The second lookup returns the cached result, that only shows on the // coverage report. - _, _, err = Lookup(ctx, tc.inHostname, nil) + _, _, err = Lookup(ctx, tc.inHostname) assert.Regexp(t, regexp.MustCompile(tc.outError), err.Error()) } } @@ -121,10 +121,10 @@ func TestLookupShuffle(t *testing.T) { // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) defer cancel() - shuffle1, _, err := Lookup(ctx, "test-mx.definbox.com", nil) + shuffle1, _, err := Lookup(ctx, "test-mx.definbox.com") assert.NoError(t, err) resetCache() - shuffle2, _, err := Lookup(ctx, "test-mx.definbox.com", nil) + shuffle2, _, err := Lookup(ctx, "test-mx.definbox.com") assert.NoError(t, err) // Then