Skip to content

Commit

Permalink
Skip UDP address serialization in muxed conn (#688)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe authored Apr 30, 2024
1 parent a834f55 commit 75ca3a2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 178 deletions.
13 changes: 9 additions & 4 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ type UDPMuxDefault struct {
localAddrsForUnspecified []net.Addr
}

const maxAddrSize = 512

// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
Expand Down Expand Up @@ -120,7 +118,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
pool: &sync.Pool{
New: func() interface{} {
// Big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
return newBufferHolder(receiveMTU)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
Expand Down Expand Up @@ -365,7 +363,9 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
}

type bufferHolder struct {
buf []byte
next *bufferHolder
buf []byte
addr *net.UDPAddr
}

func newBufferHolder(size int) *bufferHolder {
Expand All @@ -374,6 +374,11 @@ func newBufferHolder(size int) *bufferHolder {
}
}

func (b *bufferHolder) reset() {
b.next = nil
b.addr = nil
}

type ipPort struct {
addr netip.Addr
port uint16
Expand Down
68 changes: 0 additions & 68 deletions udp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,74 +126,6 @@ func TestUDPMux(t *testing.T) {
}
}

func TestAddressEncoding(t *testing.T) {
cases := []struct {
name string
addr net.UDPAddr
}{
{
name: "empty address",
},
{
name: "ipv4",
addr: net.UDPAddr{
IP: net.IPv4(244, 120, 0, 5),
Port: 6000,
Zone: "",
},
},
{
name: "ipv6",
addr: net.UDPAddr{
IP: net.IPv6loopback,
Port: 2500,
Zone: "zone",
},
},
}

for _, c := range cases {
addr := c.addr
t.Run(c.name, func(t *testing.T) {
buf := make([]byte, maxAddrSize)
n, err := encodeUDPAddr(&addr, buf)
require.NoError(t, err)

parsedAddr, err := decodeUDPAddr(buf[:n])
require.NoError(t, err)
require.EqualValues(t, &addr, parsedAddr)
})
}
}

func BenchmarkAddressEncoding(b *testing.B) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}
buf := make([]byte, 64)

b.Run("encode", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := encodeUDPAddr(addr, buf); err != nil {
require.NoError(b, err)
}
}
})

b.Run("decode", func(b *testing.B) {
n, _ := encodeUDPAddr(addr, buf)
var addr *net.UDPAddr
var err error
for i := 0; i < b.N; i++ {
if addr, err = decodeUDPAddr(buf[:n]); err != nil {
require.NoError(b, err)
}
}
_ = addr
})
}

func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag")
Expand Down
199 changes: 93 additions & 106 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@
package ice

import (
"encoding/binary"
"io"
"net"
"sync"
"time"

"github.com/pion/logging"
"github.com/pion/transport/v3/packetio"
)

type udpMuxedConnState int

const (
udpMuxedConnOpen udpMuxedConnState = iota
udpMuxedConnWaiting
udpMuxedConnClosed
)

type udpMuxedConnParams struct {
Expand All @@ -28,52 +34,61 @@ type udpMuxedConn struct {
// Remote addresses that we have sent to on this conn
addresses []ipPort

// Channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
// FIFO queue holding incoming packets
bufHead, bufTail *bufferHolder
notify chan struct{}
closedChan chan struct{}
state udpMuxedConnState
mu sync.Mutex
}

func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
return &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
notify: make(chan struct{}, 1),
closedChan: make(chan struct{}),
}

return p
}

func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)

// Read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}

dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
for {
c.mu.Lock()
if c.bufTail != nil {
pkt := c.bufTail
c.bufTail = pkt.next

if pkt == c.bufHead {
c.bufHead = nil
}
c.mu.Unlock()

if len(b) < len(pkt.buf) {
err = io.ErrShortBuffer
} else {
n = copy(b, pkt.buf)
rAddr = pkt.addr
}

pkt.reset()
c.params.AddrPool.Put(pkt)

return
}

// Read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
if c.state == udpMuxedConnClosed {
c.mu.Unlock()
return 0, nil, io.EOF
}

// Read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
c.state = udpMuxedConnWaiting
c.mu.Unlock()

if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
select {
case <-c.notify:
case <-c.closedChan:
return 0, nil, io.EOF
}
}

return dataLen, rAddr, nil
}

func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
Expand Down Expand Up @@ -118,21 +133,28 @@ func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
}

func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
c.mu.Lock()
defer c.mu.Unlock()
if c.state != udpMuxedConnClosed {
for pkt := c.bufTail; pkt != nil; {
next := pkt.next

pkt.reset()
c.params.AddrPool.Put(pkt)

pkt = next
}

c.state = udpMuxedConnClosed
close(c.closedChan)
})
return err
}
return nil
}

func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
c.mu.Lock()
defer c.mu.Unlock()
return c.state == udpMuxedConnClosed
}

func (c *udpMuxedConn) getAddresses() []ipPort {
Expand Down Expand Up @@ -178,79 +200,44 @@ func (c *udpMuxedConn) containsAddress(addr ipPort) bool {
}

func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// Write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)

// Format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
if cap(pkt.buf) < len(data) {
c.params.AddrPool.Put(pkt)
return io.ErrShortBuffer
}
// Data length
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2

// Data
copy(buf.buf[offset:], data)
offset += len(data)
pkt.buf = append(pkt.buf[:0], data...)
pkt.addr = addr

// Write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
c.mu.Lock()
if c.state == udpMuxedConnClosed {
c.mu.Unlock()

// Address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
pkt.reset()
c.params.AddrPool.Put(pkt)

if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
return io.ErrClosedPipe
}
return nil
}

func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
total := 1 + len(addr.IP) + 2 + len(addr.Zone)
if len(buf) < total {
return 0, io.ErrShortBuffer
if c.bufHead != nil {
c.bufHead.next = pkt
}
c.bufHead = pkt

buf[0] = uint8(len(addr.IP))
offset := 1

copy(buf[offset:], addr.IP)
offset += len(addr.IP)

binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2

copy(buf[offset:], addr.Zone)
return total, nil
}

func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := &net.UDPAddr{}

// Basic bounds checking
if len(buf) == 0 || len(buf) < int(buf[0])+3 {
return nil, io.ErrShortBuffer
if c.bufTail == nil {
c.bufTail = pkt
}

ipLen := int(buf[0])
offset := 1
state := c.state
c.state = udpMuxedConnOpen
c.mu.Unlock()

if ipLen == 0 {
addr.IP = nil
} else {
addr.IP = append(addr.IP[:0], buf[offset:offset+ipLen]...)
offset += ipLen
if state == udpMuxedConnWaiting {
select {
case c.notify <- struct{}{}:
default:
}
}

addr.Port = int(binary.LittleEndian.Uint16(buf[offset:]))
offset += 2

addr.Zone = string(buf[offset:])

return addr, nil
return nil
}

0 comments on commit 75ca3a2

Please sign in to comment.