From 56e946cec902605697c15d7d8897ca4dc8ab7f4a Mon Sep 17 00:00:00 2001 From: caffix Date: Wed, 24 Jul 2024 11:44:22 -0400 Subject: [PATCH] fixed some unit tests --- rate.go | 16 ++++++++++------ rate_test.go | 6 ++++-- resolvers_test.go | 2 +- selector.go | 8 +++++++- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/rate.go b/rate.go index 9ae8877..afd97e0 100644 --- a/rate.go +++ b/rate.go @@ -15,10 +15,11 @@ import ( ) const ( - maxQPSPerNameserver = 250 - numIntervalSeconds = 2 - rateUpdateInterval = numIntervalSeconds * time.Second - maxTimeoutPercentage = 0.5 + startQPSPerNameserver = 15 + maxQPSPerNameserver = 250 + numIntervalSeconds = 2 + rateUpdateInterval = numIntervalSeconds * time.Second + maxTimeoutPercentage = 0.5 ) type rateTrack struct { @@ -52,8 +53,8 @@ func NewRateTracker() *RateTracker { func newRateTrack() *rateTrack { return &rateTrack{ - qps: maxQPSPerNameserver, - rate: ratelimit.New(maxQPSPerNameserver), + qps: startQPSPerNameserver, + rate: ratelimit.New(startQPSPerNameserver), } } @@ -135,6 +136,9 @@ func (rt *rateTrack) update() { } } else { rt.qps += 1 + if rt.qps > maxQPSPerNameserver { + rt.qps = maxQPSPerNameserver + } } // update the QPS rate limiter and reset counters rt.rate = ratelimit.New(rt.qps) diff --git a/rate_test.go b/rate_test.go index 67ed029..4d37cae 100644 --- a/rate_test.go +++ b/rate_test.go @@ -21,12 +21,14 @@ func TestUpdateRateLimiters(t *testing.T) { tracker.Lock() qps := tracker.qps tracker.Unlock() - num := qps / 2 + num := qps / 3 // set a large number of timeouts for i := 0; i < num; i++ { rt.Success(domain) } - for i := 0; i < num; i++ { + + max := tracker.qps - num + for i := 0; i < max; i++ { rt.Timeout(domain) } time.Sleep(rateUpdateInterval + (rateUpdateInterval / 2)) diff --git a/resolvers_test.go b/resolvers_test.go index 90ec5b1..03d184b 100644 --- a/resolvers_test.go +++ b/resolvers_test.go @@ -29,7 +29,7 @@ func TestInitializeResolver(t *testing.T) { func TestSetTimeout(t *testing.T) { r := NewResolvers() - _ = r.AddResolvers(1000, "8.8.8.8") + _ = r.AddResolvers(maxQPSPerNameserver, "8.8.8.8") defer r.Stop() timeout := 2 * time.Second diff --git a/selector.go b/selector.go index db7764b..0f58cc0 100644 --- a/selector.go +++ b/selector.go @@ -44,8 +44,14 @@ func (r *randomSelector) GetResolver() *resolver { r.Lock() defer r.Unlock() + if l := len(r.list); l == 0 { + return nil + } else if l == 1 { + return r.list[0] + } + var chosen *resolver - sel := rand.Intn(len(r.list) + 1) + sel := rand.Intn(len(r.list)) loop: for _, res := range r.list[sel:] { select {