From 9361e3eacd2a13007017800a25dd942adf567c6e Mon Sep 17 00:00:00 2001 From: caffix Date: Fri, 2 Feb 2024 16:41:31 -0500 Subject: [PATCH] added network connection rotations --- conn.go | 137 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 91 insertions(+), 46 deletions(-) diff --git a/conn.go b/conn.go index 4a39f88..cacae9d 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ package resolve import ( "context" + "errors" "fmt" "net" "runtime" @@ -25,82 +26,123 @@ type resp struct { Addr net.Addr } +type connection struct { + conn net.PacketConn + done chan struct{} +} + type connections struct { sync.Mutex done chan struct{} - conns []net.PacketConn + conns []*connection resps queue.Queue nextWrite int + cpus int } func newConnections(cpus int, resps queue.Queue) *connections { conns := &connections{ - done: make(chan struct{}, 1), resps: resps, + done: make(chan struct{}), + cpus: cpus, } + conns.Lock() + defer conns.Unlock() + for i := 0; i < cpus; i++ { if err := conns.Add(); err != nil { conns.Close() return nil } } + go conns.rotations() return conns } -func (c *connections) Close() { - select { - case <-c.done: - return - default: +func (r *connections) Close() { + r.Lock() + defer r.Unlock() + + if r.conns != nil { + close(r.done) + for _, c := range r.conns { + close(c.done) + } + r.conns = nil } - close(c.done) - for _, conn := range c.conns { - conn.Close() +} + +func (r *connections) rotations() { + t := time.NewTicker(time.Minute) + defer t.Stop() + + for { + select { + case <-r.done: + return + case <-t.C: + r.rotate() + } } } -func (c *connections) Next() net.PacketConn { - c.Lock() - defer c.Unlock() +func (r *connections) rotate() { + r.Lock() + defer r.Unlock() - cur := c.nextWrite - c.nextWrite = (c.nextWrite + 1) % len(c.conns) - return c.conns[cur] + for _, c := range r.conns { + go func(c *connection) { + t := time.NewTimer(10 * time.Second) + defer t.Stop() + + <-t.C + close(c.done) + }(c) + } + + r.conns = []*connection{} + for i := 0; i < r.cpus; i++ { + _ = r.Add() + } } -func (c *connections) Add() error { +func (r *connections) Next() net.PacketConn { + r.Lock() + defer r.Unlock() + + if r.conns == nil || len(r.conns) == 0 { + return nil + } + + cur := r.nextWrite + r.nextWrite = (r.nextWrite + 1) % len(r.conns) + return r.conns[cur].conn +} + +func (r *connections) Add() error { var err error var conn net.PacketConn - switch runtime.GOOS { - case "android": - fallthrough - case "linux": - fallthrough - case "darwin": - fallthrough - case "freebsd": - fallthrough - case "netbsd": - fallthrough - case "openbsd": - fallthrough - case "solaris": - conn, err = c.unixListenPacket() - default: + if runtime.GOOS == "linux" { + conn, err = r.linuxListenPacket() + } else { conn, err = net.ListenPacket("udp", ":0") } if err == nil { _ = conn.SetDeadline(time.Time{}) - c.conns = append(c.conns, conn) - go c.responses(conn) + c := &connection{ + conn: conn, + done: make(chan struct{}), + } + r.conns = append(r.conns, c) + go r.responses(c) } return err } -func (c *connections) unixListenPacket() (net.PacketConn, error) { +func (r *connections) linuxListenPacket() (net.PacketConn, error) { lc := net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { var operr error @@ -116,43 +158,46 @@ func (c *connections) unixListenPacket() (net.PacketConn, error) { } laddr := ":0" - if len(c.conns) > 0 { - laddr = c.conns[0].LocalAddr().String() + if len(r.conns) > 0 { + laddr = r.conns[0].conn.LocalAddr().String() } return lc.ListenPacket(context.Background(), "udp", laddr) } -func (c *connections) WriteMsg(msg *dns.Msg, addr net.Addr) error { +func (r *connections) WriteMsg(msg *dns.Msg, addr net.Addr) error { var n int var err error var out []byte if out, err = msg.Pack(); err == nil { - conn := c.Next() + err = errors.New("failed to obtain a connection") - _ = conn.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) - if n, err = conn.WriteTo(out, addr); err == nil && n < len(out) { - err = fmt.Errorf("only wrote %d bytes of the %d byte message", n, len(out)) + if conn := r.Next(); conn != nil { + _ = conn.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) + if n, err = conn.WriteTo(out, addr); err == nil && n < len(out) { + err = fmt.Errorf("only wrote %d bytes of the %d byte message", n, len(out)) + } } } return err } -func (c *connections) responses(conn net.PacketConn) { +func (r *connections) responses(c *connection) { b := make([]byte, dns.DefaultMsgSize) for { select { case <-c.done: + _ = c.conn.Close() return default: } - if n, addr, err := conn.ReadFrom(b); err == nil && n >= headerSize { + if n, addr, err := c.conn.ReadFrom(b); err == nil && n >= headerSize { m := new(dns.Msg) if err := m.Unpack(b[:n]); err == nil && len(m.Question) > 0 { - c.resps.Append(&resp{ + r.resps.Append(&resp{ Msg: m, Addr: addr, })