Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate udp proxy #2712

Draft
wants to merge 13 commits into
base: relay/fix/wg-roaming
Choose a base branch
from
5 changes: 5 additions & 0 deletions client/iface/bind/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package bind

import wgConn "golang.zx2c4.com/wireguard/conn"

type Endpoint = wgConn.StdNetEndpoint
118 changes: 114 additions & 4 deletions client/iface/bind/bind.go → client/iface/bind/ice_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"runtime"
"strings"
"sync"

"github.com/pion/stun/v2"
Expand All @@ -13,6 +14,11 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)

type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

type receiverCreator struct {
iceBind *ICEBind
}
Expand All @@ -23,19 +29,32 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.Pack

type ICEBind struct {
*wgConn.StdNetBind

muUDPMux sync.Mutex
RecvChan chan RecvMessage

transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
endpoints map[string]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex
closed bool

filterFn FilterFn
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}

func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[string]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}

rc := receiverCreator{
Expand All @@ -45,6 +64,31 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
return ib
}

func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}

func (s *ICEBind) Close() error {
// just a quick implementation to make the tests pass
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
err := s.StdNetBind.Close()
return err

}

// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
Expand All @@ -56,6 +100,39 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}

func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr.String()] = conn
b.endpointsMu.Unlock()
return fakeAddr, nil
}

func (b *ICEBind) RemoveEndpoint(fakeAddr *net.UDPAddr) {
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr.String())
}

func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstToString()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}

for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}

func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
Expand Down Expand Up @@ -140,3 +217,36 @@ func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {

return msg, nil
}

// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()

select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}

func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}

newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
5 changes: 2 additions & 3 deletions client/iface/device/device_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"os/exec"

"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
Expand All @@ -29,14 +28,14 @@ type TunDevice struct {
configurer WGConfigurer
}

func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
}
}

Expand Down
5 changes: 2 additions & 3 deletions client/iface/device/device_netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package device
import (
"fmt"

"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"

Expand All @@ -31,15 +30,15 @@ type TunNetstackDevice struct {
configurer WGConfigurer
}

func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn),
iceBind: iceBind,
}
}

Expand Down
6 changes: 3 additions & 3 deletions client/iface/device/device_usp_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"runtime"

"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
Expand All @@ -30,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer
}

func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")

checkUser()
Expand All @@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn)}
iceBind: iceBind,
}
}

func (t *USPDevice) Create() (WGConfigurer, error) {
Expand Down
34 changes: 23 additions & 11 deletions client/iface/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ import (
"sync"
"time"

"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)

const (
Expand All @@ -28,8 +31,13 @@ type WGIface struct {
userspaceBind bool
mu sync.Mutex

configurer device.WGConfigurer
filter device.PacketFilter
configurer device.WGConfigurer
filter device.PacketFilter
wgProxyFactory *wgproxy.Factory
}

func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}

// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
Expand Down Expand Up @@ -124,22 +132,26 @@ func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()

err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
var result *multierror.Error

if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}

err = w.waitUntilRemoved()
if err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
if err := w.Destroy(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
return errors.FormatErrorOrNil(result)
}
log.Infof("interface %s successfully removed", w.Name())
}

return nil
return errors.FormatErrorOrNil(result)
}

// SetFilter sets packet filters for the userspace implementation
Expand Down
10 changes: 7 additions & 3 deletions client/iface/iface_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)

// NewWGIFace Creates a new WireGuard interface instance
Expand All @@ -21,16 +22,19 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
return nil, err
}

iceBind := bind.NewICEBind(transportNet, filterFn)

wgIFace := &WGIface{
userspaceBind: true,
userspaceBind: true,
wgProxyFactory: wgproxy.NewFactory(wgPort, iceBind),
}

if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind, netstack.ListenAddr())
return wgIFace, nil
}

wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind)

return wgIFace, nil
}
Expand Down
7 changes: 7 additions & 0 deletions client/iface/iface_moc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)

type MockWGIface struct {
Expand All @@ -30,6 +31,7 @@ type MockWGIface struct {
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}

func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
Expand Down Expand Up @@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}

func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}
Loading
Loading