diff --git a/client_dailer.go b/client_dailer.go index bc11f5a..ae63d35 100644 --- a/client_dailer.go +++ b/client_dailer.go @@ -8,7 +8,9 @@ import ( "net/http" "net/url" "reflect" + "runtime" "sync" + "sync/atomic" "time" "unsafe" ) @@ -19,8 +21,8 @@ type UDPDialer struct { MaxIdleConns int MaxConns int - mu sync.Mutex - conns []net.Conn + mu sync.Mutex + pool *sync.Pool } func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { @@ -31,38 +33,33 @@ func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn return } -func (d *UDPDialer) get() (conn net.Conn, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - count := len(d.conns) - if d.MaxConns != 0 && count > d.MaxConns { - err = ErrMaxConns +func udpDialerFinalizer(c *net.UDPConn) { + _ = c.Close() +} - return - } - if count > 0 { - conn = d.conns[len(d.conns)-1] - d.conns = d.conns[:len(d.conns)-1] +func (d *UDPDialer) get() (conn net.Conn, err error) { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool))) == nil { + d.mu.Lock() + pool := &sync.Pool{ + New: func() any { + c, err := net.DialUDP("udp", nil, d.Addr) + if err != nil { + return nil + } + runtime.SetFinalizer(c, udpDialerFinalizer) + return c + }, + } + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.pool)), unsafe.Pointer(&pool)) + d.mu.Unlock() } - return + conn, _ = d.pool.Get().(net.Conn) + return conn, nil } func (d *UDPDialer) Put(conn net.Conn) { - if _, ok := conn.(*net.UDPConn); !ok { - return - } - - d.mu.Lock() - defer d.mu.Unlock() - - if (d.MaxIdleConns != 0 && len(d.conns) > d.MaxIdleConns) || (d.MaxConns != 0 && len(d.conns) > d.MaxConns) { - conn.Close() - return - } - - d.conns = append(d.conns, conn) + d.pool.Put(conn) } type HTTPDialer struct {