diff --git a/client_dailer.go b/client_dailer.go index bc11f5a..225a249 100644 --- a/client_dailer.go +++ b/client_dailer.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "reflect" + "runtime" "sync" "time" "unsafe" @@ -19,8 +20,8 @@ type UDPDialer struct { MaxIdleConns int MaxConns int - mu sync.Mutex - conns []net.Conn + mu sync.Mutex + pool sync.Pool } func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { @@ -31,38 +32,30 @@ func (d *UDPDialer) DialContext(ctx context.Context, network, addr string) (conn return } -func (d *UDPDialer) get() (conn net.Conn, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - count := len(d.conns) - if d.MaxConns != 0 && count > d.MaxConns { - err = ErrMaxConns +func udpDialerFinalizer(c *net.UDPConn) { + _ = c.Close() +} - return - } - if count > 0 { - conn = d.conns[len(d.conns)-1] - d.conns = d.conns[:len(d.conns)-1] +func (d *UDPDialer) get() (conn net.Conn, err error) { + if d.pool.New == nil { + d.mu.Lock() + d.pool.New = func() any { + c, err := net.DialUDP("udp", nil, d.Addr) + if err != nil { + return nil + } + runtime.SetFinalizer(c, udpDialerFinalizer) + return c + } + d.mu.Unlock() } - return + conn, _ = d.pool.Get().(net.Conn) + return conn, nil } func (d *UDPDialer) Put(conn net.Conn) { - if _, ok := conn.(*net.UDPConn); !ok { - return - } - - d.mu.Lock() - defer d.mu.Unlock() - - if (d.MaxIdleConns != 0 && len(d.conns) > d.MaxIdleConns) || (d.MaxConns != 0 && len(d.conns) > d.MaxConns) { - conn.Close() - return - } - - d.conns = append(d.conns, conn) + d.pool.Put(conn) } type HTTPDialer struct {