diff --git a/client_dailer.go b/client_dailer.go index bc11f5a..4d54e7a 100644 --- a/client_dailer.go +++ b/client_dailer.go @@ -8,19 +8,21 @@ import ( "net/http" "net/url" "reflect" + "runtime" "sync" + "sync/atomic" "time" "unsafe" ) type UDPDialer struct { - Addr *net.UDPAddr - Timeout time.Duration - MaxIdleConns int - MaxConns int + Addr *net.UDPAddr + Timeout time.Duration + MaxConns int64 - mu sync.Mutex - conns []net.Conn + mu sync.Mutex + pool *sync.Pool + size int64 } func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { @@ -31,38 +33,39 @@ func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn return } -func (d *UDPDialer) get() (conn net.Conn, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - count := len(d.conns) - if d.MaxConns != 0 && count > d.MaxConns { - err = ErrMaxConns +func udpDialerFinalizer(c net.Conn) { + _ = c.Close() +} - return - } - if count > 0 { - conn = d.conns[len(d.conns)-1] - d.conns = d.conns[:len(d.conns)-1] +func (d *UDPDialer) get() (conn net.Conn, err error) { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool))) == nil { + d.mu.Lock() + if d.pool == nil { + pool := &sync.Pool{ + New: func() any { + c, err := net.DialUDP("udp", nil, d.Addr) + if err != nil { + return nil + } + return c + }, + } + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool)), unsafe.Pointer(&pool)) + } + d.mu.Unlock() } - return + conn, _ = d.pool.Get().(net.Conn) + return conn, nil } func (d *UDPDialer) Put(conn net.Conn) { - if _, ok := conn.(*net.UDPConn); !ok { - return - } - - d.mu.Lock() - defer d.mu.Unlock() - - if (d.MaxIdleConns != 0 && len(d.conns) > d.MaxIdleConns) || (d.MaxConns != 0 && len(d.conns) > d.MaxConns) { + if atomic.AddInt64(&d.size, 1) <= d.MaxConns { + d.pool.Put(conn) + runtime.SetFinalizer(conn, udpDialerFinalizer) + } else { conn.Close() - return } - - d.conns = append(d.conns, conn) } type HTTPDialer struct { diff --git a/client_test.go b/client_test.go index ded8be6..ce10fb7 100644 --- a/client_test.go +++ b/client_test.go @@ -147,3 +147,45 @@ func TestClientLookup(t *testing.T) { } } } + +func BenchmarkPureGoNetResolver(b *testing.B) { + resolver := net.Resolver{PreferGo: true} + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(b *testing.PB) { + for b.Next() { + resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") + } + }) +} + +func BenchmarkCGoNetResolver(b *testing.B) { + resolver := net.Resolver{PreferGo: false} + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(b *testing.PB) { + for b.Next() { + resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") + } + }) +} + +func BenchmarkFastdnsResolver(b *testing.B) { + resolver := Client{ + Addr: "1.1.1.1:53", + Dialer: &UDPDialer{ + Addr: func() (u *net.UDPAddr) { u, _ = net.ResolveUDPAddr("udp", "1.1.1.1:53"); return }(), + MaxConns: 1000, + }, + } + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(b *testing.PB) { + for b.Next() { + resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") + } + }) +}