diff --git a/common/leakybuf/leakybuf.go b/common/leakybuf/leakybuf.go index 76ae7f9..24f8066 100644 --- a/common/leakybuf/leakybuf.go +++ b/common/leakybuf/leakybuf.go @@ -1,3 +1,5 @@ +// modified from https://github.com/nadoo/glider/blob/master/pool/buffer.go + package leakybuf import ( @@ -5,74 +7,53 @@ import ( "sync" ) -var pool = make(map[int]chan []byte) -var mu sync.RWMutex +const ( + // number of pools. + num = 17 + maxsize = 1 << (num - 1) + UDPBufSize = 64 * 1024 +) -const maxPoolSize = 256 -const UDPBufSize = 64 * 1024 -const maxSize = 1 << 17 +var ( + sizes [num]int + pools [num]sync.Pool +) -func getClosestSize(need int) (size int) { - // if need is exactly 2^n, return it - if need&(need-1) == 0 { - return need +func init() { + for i := 0; i < num; i++ { + size := 1 << i + sizes[i] = size + pools[i].New = func() interface{} { + return make([]byte, size) + } } - // or return its closest 2^n - return 1 << bits.Len(uint(need)) } -func Get(need int) []byte { - if need > maxSize { - return make([]byte, need) - } - size := getClosestSize(need) - mu.RLock() - c, ok := pool[size] - if !ok { - mu.RUnlock() - mu.Lock() - if c, ok = pool[size]; !ok { - pool[size] = make(chan []byte, maxPoolSize) - mu.Unlock() - return make([]byte, need, size) - } - mu.Unlock() - } else { - mu.RUnlock() +func getClosestN(need int) (n int) { + // if need is exactly 2^n, return n-1 + if need&(need-1) == 0 { + return bits.Len32(uint32(need)) - 1 } - select { - case buf := <-c: - return buf[:need] - default: + // or return its closest n + return bits.Len32(uint32(need)) +} + +// Get gets a buffer from pool, size should in range: [1, 65536], +// otherwise, this function will call make([]byte, size) directly. +func Get(size int) []byte { + if size >= 1 && size <= maxsize { + i := getClosestN(size) + return pools[i].Get().([]byte)[:size] } - return make([]byte, need, size) + return make([]byte, size) } +// Put puts a buffer into pool. func Put(buf []byte) { - size := cap(buf) - if size > maxSize { - return - } - mu.RLock() - c, ok := pool[size] - if ok { - mu.RUnlock() - select { - case c <- buf[:size]: - default: - } - } else { - mu.RUnlock() - mu.Lock() - if c, ok = pool[size]; !ok { - pool[size] = make(chan []byte, maxPoolSize) - mu.Unlock() - } else { - mu.Unlock() - } - select { - case pool[size] <- buf[:size]: - default: + if size := cap(buf); size >= 1 && size <= maxsize { + i := getClosestN(size) + if i < num { + pools[i].Put(buf) } } } diff --git a/config/config.go b/config/config.go index 8ae7e86..2179555 100644 --- a/config/config.go +++ b/config/config.go @@ -148,6 +148,7 @@ func GetConfig() *Config { if err = json.Unmarshal(b, config); err != nil { log.Fatalln(err) } + log.Println("pulling configures from upstreams...") if err = parseUpstreams(config); err != nil { log.Fatalln(err) } diff --git a/dispatcher/udp/udp.go b/dispatcher/udp/udp.go index 805515a..2f0b1e3 100644 --- a/dispatcher/udp/udp.go +++ b/dispatcher/udp/udp.go @@ -33,20 +33,21 @@ func (d *Dispatcher) Listen() (err error) { } defer d.c.Close() log.Printf("[udp] listen on :%v\n", d.group.Port) + var buf [leakybuf.UDPBufSize]byte for { - buf := leakybuf.Get(leakybuf.UDPBufSize) - n, laddr, err := d.c.ReadFrom(buf) + n, laddr, err := d.c.ReadFrom(buf[:]) if err != nil { log.Printf("[error] ReadFrom: %v", err) - leakybuf.Put(buf) continue } + data := leakybuf.Get(n) + copy(data, buf[:n]) go func() { - err := d.handleConn(laddr, buf, n) + err := d.handleConn(laddr, data, n) if err != nil { log.Println(err) } - leakybuf.Put(buf) + leakybuf.Put(data) }() } } @@ -56,7 +57,7 @@ func selectTimeout(packet []byte) time.Duration { var dmessage dnsmessage.Message err := dmessage.Unpack(packet) if err != nil { - return defaultTimeout + return defaultNatTimeout } return dnsQueryTimeout } @@ -80,7 +81,6 @@ func (d *Dispatcher) handleConn(laddr net.Addr, data []byte, n int) (err error) } // send packet - log.Printf("[udp] %s <-> %s <-> %s", laddr.String(), d.c.LocalAddr(), rc.RemoteAddr()) _, err = rc.Write(data[:n]) if err != nil { return fmt.Errorf("[udp] handleConn write error: %v", err) @@ -89,14 +89,14 @@ func (d *Dispatcher) handleConn(laddr net.Addr, data []byte, n int) (err error) } // connTimeout is the timeout of connection to build if not exists -func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, target string, connTimeout time.Duration) (rc *net.UDPConn, err error) { +func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, target string, natTimeout time.Duration) (rc *UDPConn, err error) { socketIdent := laddr.String() d.nm.Lock() var conn *UDPConn var ok bool if conn, ok = d.nm.Get(socketIdent); !ok { // not exist such socket mapping, build one - d.nm.Insert(socketIdent, nil, 3600*time.Second) + d.nm.Insert(socketIdent, nil) d.nm.Unlock() rconn, err := net.Dial("udp", target) if err != nil { @@ -105,14 +105,16 @@ func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, target string, connTimeou d.nm.Unlock() return nil, fmt.Errorf("GetOrBuildUCPConn dial error: %v", err) } - rc = rconn.(*net.UDPConn) + _rconn := rconn.(*net.UDPConn) d.nm.Lock() d.nm.Remove(socketIdent) // close channel to inform that establishment ends - d.nm.Insert(socketIdent, rc, connTimeout) + d.nm.Insert(socketIdent, _rconn) + rc, _ = d.nm.Get(socketIdent) d.nm.Unlock() // relay + log.Printf("[udp] %s <-> %s <-> %s", laddr.String(), d.c.LocalAddr(), rc.RemoteAddr()) go func() { - _ = relay(d.c, laddr, rc, connTimeout) + _ = relay(d.c, laddr, _rconn, natTimeout) d.nm.Lock() d.nm.Remove(socketIdent) d.nm.Unlock() @@ -123,10 +125,10 @@ func (d *Dispatcher) GetOrBuildUCPConn(laddr net.Addr, target string, connTimeou <-conn.Establishing if conn.UDPConn == nil { // establishment ended and retrieve the result - return d.GetOrBuildUCPConn(laddr, target, connTimeout) + return d.GetOrBuildUCPConn(laddr, target, natTimeout) } else { // establishment succeeded - rc = conn.UDPConn + rc = conn } } return rc, nil @@ -142,6 +144,7 @@ func relay(dst *net.UDPConn, laddr net.Addr, src *net.UDPConn, timeout time.Dura if err != nil { return } + _ = dst.SetWriteDeadline(time.Now().Add(timeout)) _, err = dst.WriteTo(buf[:n], laddr) if err != nil { return diff --git a/dispatcher/udp/udpConn.go b/dispatcher/udp/udpConn.go index 5039a23..7fb7dda 100644 --- a/dispatcher/udp/udpConn.go +++ b/dispatcher/udp/udpConn.go @@ -7,23 +7,19 @@ import ( ) const ( - defaultTimeout = 120 * time.Second - dnsQueryTimeout = 17 * time.Second + defaultNatTimeout = 2 * time.Minute + dnsQueryTimeout = 17 * time.Second ) type UDPConn struct { Establishing chan struct{} *net.UDPConn - lastVisitTime time.Time - timeout time.Duration } -func NewUDPConn(conn *net.UDPConn, timeout time.Duration) *UDPConn { +func NewUDPConn(conn *net.UDPConn) *UDPConn { c := &UDPConn{ - UDPConn: conn, - lastVisitTime: time.Now(), - timeout: timeout, - Establishing: make(chan struct{}), + UDPConn: conn, + Establishing: make(chan struct{}), } if c.UDPConn != nil { close(c.Establishing) @@ -34,50 +30,26 @@ func NewUDPConn(conn *net.UDPConn, timeout time.Duration) *UDPConn { type UDPConnMapping struct { nm map[string]*UDPConn sync.Mutex - cleanTicker *time.Ticker -} - -func (m *UDPConnMapping) cleaner() { - for t := range m.cleanTicker.C { - m.Lock() - for k, v := range m.nm { - if t.Sub(v.lastVisitTime) > v.timeout { - delete(m.nm, k) - } - } - m.Unlock() - } } func NewUDPConnMapping() *UDPConnMapping { m := &UDPConnMapping{ - nm: make(map[string]*UDPConn), - cleanTicker: time.NewTicker(2 * time.Second), + nm: make(map[string]*UDPConn), } - go m.cleaner() return m } -func (m *UDPConnMapping) Close() error { - m.cleanTicker.Stop() - return nil -} - func (m *UDPConnMapping) Get(key string) (conn *UDPConn, ok bool) { v, ok := m.nm[key] if ok { - if time.Since(v.lastVisitTime) > defaultTimeout { - return nil, false - } - v.lastVisitTime = time.Now() conn = v } return } // pass val=nil for stating it is establishing -func (m *UDPConnMapping) Insert(key string, val *net.UDPConn, timeout time.Duration) { - m.nm[key] = NewUDPConn(val, timeout) +func (m *UDPConnMapping) Insert(key string, val *net.UDPConn) { + m.nm[key] = NewUDPConn(val) } func (m *UDPConnMapping) Remove(key string) {