Skip to content

Commit

Permalink
use sync.Pool in UDPDialer
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Oct 22, 2024
1 parent 585ec96 commit f358a6e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
61 changes: 31 additions & 30 deletions client_dailer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@ import (
"net/http"
"net/url"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
)

type UDPDialer struct {
Addr *net.UDPAddr
Timeout time.Duration
MaxIdleConns int
MaxConns int
Addr *net.UDPAddr
Timeout time.Duration
MaxConns int64

mu sync.Mutex
conns []net.Conn
mu sync.Mutex
pool *sync.Pool
size int64
}

func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
Expand All @@ -31,38 +33,37 @@ func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn
return
}

func (d *UDPDialer) get() (conn net.Conn, err error) {
d.mu.Lock()
defer d.mu.Unlock()

count := len(d.conns)
if d.MaxConns != 0 && count > d.MaxConns {
err = ErrMaxConns
func udpDialerFinalizer(c *net.UDPConn) {
_ = c.Close()
}

return
}
if count > 0 {
conn = d.conns[len(d.conns)-1]
d.conns = d.conns[:len(d.conns)-1]
func (d *UDPDialer) get() (conn net.Conn, err error) {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool))) == nil {
d.mu.Lock()
pool := &sync.Pool{
New: func() any {
c, err := net.DialUDP("udp", nil, d.Addr)
if err != nil {
return nil
}
runtime.SetFinalizer(c, udpDialerFinalizer)
return c
},
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool)), unsafe.Pointer(&pool))
d.mu.Unlock()
}

return
conn, _ = d.pool.Get().(net.Conn)
return conn, nil
}

func (d *UDPDialer) Put(conn net.Conn) {
if _, ok := conn.(*net.UDPConn); !ok {
return
}

d.mu.Lock()
defer d.mu.Unlock()

if (d.MaxIdleConns != 0 && len(d.conns) > d.MaxIdleConns) || (d.MaxConns != 0 && len(d.conns) > d.MaxConns) {
if atomic.AddInt64(&d.size, 1) <= d.MaxConns {
d.pool.Put(conn)
} else {
conn.Close()
return
}

d.conns = append(d.conns, conn)
}

type HTTPDialer struct {
Expand Down
42 changes: 42 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,45 @@ func TestClientLookup(t *testing.T) {
}
}
}

func BenchmarkPureGoNetResolver(b *testing.B) {
resolver := net.Resolver{PreferGo: true}

b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(b *testing.PB) {
for b.Next() {
resolver.LookupNetIP(context.Background(), "ip4", "www.google.com")
}
})
}

func BenchmarkCGoNetResolver(b *testing.B) {
resolver := net.Resolver{PreferGo: false}

b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(b *testing.PB) {
for b.Next() {
resolver.LookupNetIP(context.Background(), "ip4", "www.google.com")
}
})
}

func BenchmarkFastdnsResolver(b *testing.B) {
resolver := Client{
Addr: "1.1.1.1:53",
Dialer: &UDPDialer{
Addr: func() (u *net.UDPAddr) { u, _ = net.ResolveUDPAddr("udp", "1.1.1.1:53"); return }(),
MaxConns: 1000,
},
}

b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(b *testing.PB) {
for b.Next() {
resolver.LookupNetIP(context.Background(), "ip4", "www.google.com")
}
})
}

0 comments on commit f358a6e

Please sign in to comment.