diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go new file mode 100644 index 0000000000..1926ff88f1 --- /dev/null +++ b/client/iface/bind/endpoint.go @@ -0,0 +1,5 @@ +package bind + +import wgConn "golang.zx2c4.com/wireguard/conn" + +type Endpoint = wgConn.StdNetEndpoint diff --git a/client/iface/bind/bind.go b/client/iface/bind/ice_bind.go similarity index 50% rename from client/iface/bind/bind.go rename to client/iface/bind/ice_bind.go index ba6153cb73..6bf24b3c95 100644 --- a/client/iface/bind/bind.go +++ b/client/iface/bind/ice_bind.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "runtime" + "strings" "sync" "github.com/pion/stun/v2" @@ -13,6 +14,11 @@ import ( wgConn "golang.zx2c4.com/wireguard/conn" ) +type RecvMessage struct { + Endpoint *Endpoint + Buffer []byte +} + type receiverCreator struct { iceBind *ICEBind } @@ -23,19 +29,32 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.Pack type ICEBind struct { *wgConn.StdNetBind - - muUDPMux sync.Mutex + RecvChan chan RecvMessage transportNet transport.Net - udpMux *UniversalUDPMuxDefault + filterFn FilterFn + endpoints map[string]net.Conn + endpointsMu sync.Mutex + // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a + // new closed channel. With the closedChanMu we can safely close the channel and create a new one + closedChan chan struct{} + closedChanMu sync.RWMutex + closed bool - filterFn FilterFn + muUDPMux sync.Mutex + udpMux *UniversalUDPMuxDefault } func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { + b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ + StdNetBind: b, + RecvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, + endpoints: make(map[string]net.Conn), + closedChan: make(chan struct{}), + closed: true, } rc := receiverCreator{ @@ -45,6 +64,31 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { return ib } +func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + s.closed = false + s.closedChanMu.Lock() + s.closedChan = make(chan struct{}) + s.closedChanMu.Unlock() + fns, port, err := s.StdNetBind.Open(uport) + if err != nil { + return nil, 0, err + } + fns = append(fns, s.receiveRelayed) + return fns, port, nil +} + +func (s *ICEBind) Close() error { + // just a quick implementation to make the tests pass + if s.closed { + return nil + } + s.closed = true + close(s.closedChan) + err := s.StdNetBind.Close() + return err + +} + // GetICEMux returns the ICE UDPMux that was created and used by ICEBind func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() @@ -56,6 +100,39 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { return s.udpMux, nil } +func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { + fakeAddr, err := fakeAddress(peerAddress) + if err != nil { + return nil, err + } + b.endpointsMu.Lock() + b.endpoints[fakeAddr.String()] = conn + b.endpointsMu.Unlock() + return fakeAddr, nil +} + +func (b *ICEBind) RemoveEndpoint(fakeAddr *net.UDPAddr) { + b.endpointsMu.Lock() + defer b.endpointsMu.Unlock() + delete(b.endpoints, fakeAddr.String()) +} + +func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { + b.endpointsMu.Lock() + conn, ok := b.endpoints[ep.DstToString()] + b.endpointsMu.Unlock() + if !ok { + return b.StdNetBind.Send(bufs, ep) + } + + for _, buf := range bufs { + if _, err := conn.Write(buf); err != nil { + return err + } + } + return nil +} + func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() @@ -140,3 +217,36 @@ func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) { return msg, nil } + +// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the +// WireGuard. Critical part is do not block if the Closed() has been called. +func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + c.closedChanMu.RLock() + defer c.closedChanMu.RUnlock() + + select { + case <-c.closedChan: + return 0, net.ErrClosed + case msg, ok := <-c.RecvChan: + if !ok { + return 0, net.ErrClosed + } + copy(buffs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = wgConn.Endpoint(msg.Endpoint) + return 1, nil + } +} + +func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) { + octets := strings.Split(peerAddress.IP.String(), ".") + if len(octets) != 4 { + return nil, fmt.Errorf("invalid IP format") + } + + newAddr := &net.UDPAddr{ + IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])), + Port: peerAddress.Port, + } + return newAddr, nil +} diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 03e85a7f17..b5a128bc1c 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -6,7 +6,6 @@ import ( "fmt" "os/exec" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -29,14 +28,14 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 440a1ca191..f5d39e9e07 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -6,7 +6,6 @@ package device import ( "fmt" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" @@ -31,7 +30,7 @@ type TunNetstackDevice struct { configurer WGConfigurer } -func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4175f65569..643d77565c 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -7,7 +7,6 @@ import ( "os" "runtime" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -30,7 +29,7 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") checkUser() @@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn)} + iceBind: iceBind, + } } func (t *USPDevice) Create() (WGConfigurer, error) { diff --git a/client/iface/iface.go b/client/iface/iface.go index accf5ce0af..b55143997e 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -6,12 +6,15 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) const ( @@ -28,8 +31,13 @@ type WGIface struct { userspaceBind bool mu sync.Mutex - configurer device.WGConfigurer - filter device.PacketFilter + configurer device.WGConfigurer + filter device.PacketFilter + wgProxyFactory *wgproxy.Factory +} + +func (w *WGIface) GetProxy() wgproxy.Proxy { + return w.wgProxyFactory.GetProxy() } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -124,22 +132,26 @@ func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - err := w.tun.Close() - if err != nil { - return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) + var result *multierror.Error + + if err := w.wgProxyFactory.Free(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - err = w.waitUntilRemoved() - if err != nil { + if err := w.tun.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) + } + + if err := w.waitUntilRemoved(); err != nil { log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) - err = w.Destroy() - if err != nil { - return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) + if err := w.Destroy(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)) + return errors.FormatErrorOrNil(result) } log.Infof("interface %s successfully removed", w.Name()) } - return nil + return errors.FormatErrorOrNil(result) } // SetFilter sets packet filters for the userspace implementation diff --git a/client/iface/iface_darwin.go b/client/iface/iface_darwin.go index b46ea0f806..898311e11d 100644 --- a/client/iface/iface_darwin.go +++ b/client/iface/iface_darwin.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance @@ -21,16 +22,19 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, return nil, err } + iceBind := bind.NewICEBind(transportNet, filterFn) + wgIFace := &WGIface{ - userspaceBind: true, + userspaceBind: true, + wgProxyFactory: wgproxy.NewFactory(wgPort, iceBind), } if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind, netstack.ListenAddr()) return wgIFace, nil } - wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind) return wgIFace, nil } diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index 703da9ce00..d91a7224ff 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type MockWGIface struct { @@ -30,6 +31,7 @@ type MockWGIface struct { GetDeviceFunc func() *device.FilteredDevice GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) + GetProxyFunc func() wgproxy.Proxy } func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { @@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } + +func (m *MockWGIface) GetProxy() wgproxy.Proxy { + //TODO implement me + panic("implement me") +} diff --git a/client/iface/iface_unix.go b/client/iface/iface_unix.go index 09dbb2c1f7..2ad1953af3 100644 --- a/client/iface/iface_unix.go +++ b/client/iface/iface_unix.go @@ -10,7 +10,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance @@ -22,25 +22,31 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, wgIFace := &WGIface{} - // move the kernel/usp/netstack preference evaluation to upper layer - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + /* + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(transportNet, filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewFactory(wgPort, iceBind) + return wgIFace, nil + } + + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.userspaceBind = false + wgIFace.wgProxyFactory = wgproxy.NewFactory(wgPort, nil) + return wgIFace, nil + } + */ + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(transportNet, filterFn) + wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind) wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewFactory(wgPort, iceBind) return wgIFace, nil } - if device.WireGuardModuleIsLoaded() { - wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) - wgIFace.userspaceBind = false - return wgIFace, nil - } - - if !device.ModuleTunIsLoaded() { - return nil, fmt.Errorf("couldn't check or load tun module") - } - wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) - wgIFace.userspaceBind = true - return wgIFace, nil + return nil, fmt.Errorf("couldn't check or load tun module") } // CreateOnAndroid this function make sense on mobile only diff --git a/client/iface/iface_windows.go b/client/iface/iface_windows.go index 6845ef3ddd..08d5033be6 100644 --- a/client/iface/iface_windows.go +++ b/client/iface/iface_windows.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance @@ -17,12 +18,15 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, return nil, err } + iceBind := bind.NewICEBind(transportNet, filterFn) + wgIFace := &WGIface{ - userspaceBind: true, + userspaceBind: true, + wgProxyFactory: wgproxy.NewFactory(wgPort, iceBind), } if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, iceBind, netstack.ListenAddr()) return wgIFace, nil } diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index cb6d7ccd9a..f5ab295390 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -22,6 +23,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 6baeb66ae0..96eec52a50 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -20,6 +21,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go new file mode 100644 index 0000000000..e986d6d7b0 --- /dev/null +++ b/client/iface/wgproxy/bind/proxy.go @@ -0,0 +1,137 @@ +package bind + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/bind" +) + +type ProxyBind struct { + Bind *bind.ICEBind + + wgAddr *net.UDPAddr + wgEndpoint *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool +} + +// AddTurnConn adds a new connection to the bind. +// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the +// WireGuard configuration. +func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { + addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) + if err != nil { + return err + } + + p.wgAddr = addr + p.wgEndpoint = addrToEndpoint(addr) + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + return err + +} +func (p *ProxyBind) EndpointAddr() *net.UDPAddr { + return p.wgAddr +} + +func (p *ProxyBind) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + // Start the proxy only once + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyBind) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() +} + +func (p *ProxyBind) CloseConn() error { + if p.cancel == nil { + return fmt.Errorf("proxy not started") + } + return p.close() +} + +func (p *ProxyBind) close() error { + p.closeMu.Lock() + defer p.closeMu.Unlock() + + if p.closed { + return nil + } + p.closed = true + + p.cancel() + + p.Bind.RemoveEndpoint(p.wgAddr) + + return p.remoteConn.Close() +} + +func (p *ProxyBind) proxyToLocal(ctx context.Context) { + defer func() { + if err := p.close(); err != nil { + log.Warnf("failed to close remote conn: %s", err) + } + }() + + buf := make([]byte, 1500) + for { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + msg := bind.RecvMessage{ + Endpoint: p.wgEndpoint, + Buffer: buf[:n], + } + p.Bind.RecvChan <- msg + p.pausedMu.Unlock() + } +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} +} diff --git a/client/internal/wgproxy/ebpf/portlookup.go b/client/iface/wgproxy/ebpf/portlookup.go similarity index 100% rename from client/internal/wgproxy/ebpf/portlookup.go rename to client/iface/wgproxy/ebpf/portlookup.go diff --git a/client/internal/wgproxy/ebpf/portlookup_test.go b/client/iface/wgproxy/ebpf/portlookup_test.go similarity index 100% rename from client/internal/wgproxy/ebpf/portlookup_test.go rename to client/iface/wgproxy/ebpf/portlookup_test.go diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go similarity index 99% rename from client/internal/wgproxy/ebpf/proxy.go rename to client/iface/wgproxy/ebpf/proxy.go index e850f4533c..e21fc35d4e 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result *multierror.Error - if p.conn != nil { // p.conn will be nil if we have failed to listen + if p.conn != nil { if err := p.conn.Close(); err != nil { result = multierror.Append(result, err) } diff --git a/client/internal/wgproxy/ebpf/proxy_test.go b/client/iface/wgproxy/ebpf/proxy_test.go similarity index 100% rename from client/internal/wgproxy/ebpf/proxy_test.go rename to client/iface/wgproxy/ebpf/proxy_test.go diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go similarity index 95% rename from client/internal/wgproxy/ebpf/wrapper.go rename to client/iface/wgproxy/ebpf/wrapper.go index b6a8ac4522..efd5fd946c 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -28,7 +28,7 @@ type ProxyWrapper struct { isStarted bool } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) diff --git a/client/iface/wgproxy/factory_linux.go b/client/iface/wgproxy/factory_linux.go new file mode 100644 index 0000000000..d75dd8b2d9 --- /dev/null +++ b/client/iface/wgproxy/factory_linux.go @@ -0,0 +1,76 @@ +//go:build !android + +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/bind" + proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +type proxyMode int + +const ( + proxyModeUDP proxyMode = iota + proxyModeEBPF + proxyModeBind +) + +type Factory struct { + wgPort int + mode proxyMode + + ebpfProxy *ebpf.WGEBPFProxy + bind *bind.ICEBind +} + +func NewFactory(wgPort int, iceBind *bind.ICEBind) *Factory { + f := &Factory{ + wgPort: wgPort, + } + + if iceBind != nil { + f.bind = iceBind + f.mode = proxyModeBind + return f + } + + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + if err := ebpfProxy.Listen(); err != nil { + log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) + f.mode = proxyModeUDP + return f + } + f.ebpfProxy = ebpfProxy + f.mode = proxyModeEBPF + return f +} + +func (w *Factory) GetProxy() Proxy { + switch w.mode { + case proxyModeUDP: + return udpProxy.NewWGUDPProxy(w.wgPort) + case proxyModeEBPF: + p := &ebpf.ProxyWrapper{ + WgeBPFProxy: w.ebpfProxy, + } + return p + case proxyModeBind: + p := &proxyBind.ProxyBind{ + Bind: w.bind, + } + return p + default: + return nil + } +} + +func (w *Factory) Free() error { + if w.ebpfProxy == nil { + return nil + } + return w.ebpfProxy.Free() +} diff --git a/client/iface/wgproxy/factory_nonlinux.go b/client/iface/wgproxy/factory_nonlinux.go new file mode 100644 index 0000000000..f60ac8436e --- /dev/null +++ b/client/iface/wgproxy/factory_nonlinux.go @@ -0,0 +1,30 @@ +//go:build !linux || android + +package wgproxy + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +type Factory struct { + bind *bind.ICEBind + port int +} + +func NewFactory(port int, bind *bind.ICEBind) *Factory { + return &Factory{ + port: port, + bind: bind, + } +} + +func (w *Factory) GetProxy() Proxy { + return &proxyBind.ProxyBind{ + Bind: w.bind, + } +} + +func (w *Factory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go new file mode 100644 index 0000000000..243aa2bd2a --- /dev/null +++ b/client/iface/wgproxy/proxy.go @@ -0,0 +1,15 @@ +package wgproxy + +import ( + "context" + "net" +) + +// Proxy is a transfer layer between the relayed connection and the WireGuard +type Proxy interface { + AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error + EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint + Work() // Work start or resume the proxy + Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + CloseConn() error +} diff --git a/client/internal/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go similarity index 90% rename from client/internal/wgproxy/proxy_test.go rename to client/iface/wgproxy/proxy_test.go index b88ff3f83c..64b6176211 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }{ { name: "userspace proxy", - proxy: usp.NewWGUserSpaceProxy(51830), + proxy: udpProxy.NewWGUDPProxy(51830), }, } @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/iface/wgproxy/udp/proxy.go similarity index 83% rename from client/internal/wgproxy/usp/proxy.go rename to client/iface/wgproxy/udp/proxy.go index f73500717a..8bee099014 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,4 +1,4 @@ -package usp +package udp import ( "context" @@ -12,8 +12,8 @@ import ( "github.com/netbirdio/netbird/client/errors" ) -// WGUserSpaceProxy proxies -type WGUserSpaceProxy struct { +// WGUDPProxy proxies +type WGUDPProxy struct { localWGListenPort int remoteConn net.Conn @@ -28,10 +28,10 @@ type WGUserSpaceProxy struct { isStarted bool } -// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation -func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { +// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation +func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) - p := &WGUserSpaceProxy{ + p := &WGUDPProxy{ localWGListenPort: wgPort, } return p @@ -42,7 +42,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { @@ -57,7 +57,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) return err } -func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { +func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr { if p.localConn == nil { return nil } @@ -66,7 +66,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { } // Work starts the proxy or resumes it if it was paused -func (p *WGUserSpaceProxy) Work() { +func (p *WGUDPProxy) Work() { if p.remoteConn == nil { return } @@ -83,7 +83,7 @@ func (p *WGUserSpaceProxy) Work() { } // Pause pauses the proxy from receiving data from the remote peer -func (p *WGUserSpaceProxy) Pause() { +func (p *WGUDPProxy) Pause() { if p.remoteConn == nil { return } @@ -94,14 +94,14 @@ func (p *WGUserSpaceProxy) Pause() { } // CloseConn close the localConn -func (p *WGUserSpaceProxy) CloseConn() error { +func (p *WGUDPProxy) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") } return p.close() } -func (p *WGUserSpaceProxy) close() error { +func (p *WGUDPProxy) close() error { p.closeMu.Lock() defer p.closeMu.Unlock() @@ -125,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { +func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -157,7 +157,7 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { // proxyToLocal proxies from the Remote peer to local WireGuard // if the proxy is paused it will drain the remote conn and drop the packets -func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { +func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index eac8ec098f..76398ef344 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -35,7 +35,6 @@ import ( "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/wgproxy" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -141,8 +140,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface iface.IWGIface - wgProxyFactory *wgproxy.Factory + wgInterface iface.IWGIface udpMux *bind.UniversalUDPMuxDefault @@ -299,9 +297,6 @@ func (e *Engine) Start() error { } e.wgInterface = wgIface - userspace := e.wgInterface.IsUserspaceBind() - e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort) - if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") if e.config.RosenpassPermissive { @@ -966,7 +961,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) if err != nil { return nil, err } @@ -1117,12 +1112,6 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { - if e.wgProxyFactory != nil { - if err := e.wgProxyFactory.Free(); err != nil { - log.Errorf("failed closing ebpf proxy: %s", err) - } - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1b740388d9..99acfde314 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -17,8 +17,8 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -81,11 +81,10 @@ type Conn struct { ctxCancel context.CancelFunc config ConnConfig statusRecorder *Status - wgProxyFactory *wgproxy.Factory signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager - allowedIPsIP string + allowedIP net.IP + allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -116,8 +115,8 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { - _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps) +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { + allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -127,19 +126,17 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - iFaceDiscover: iFaceDiscover, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), - + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + signaler: signaler, + relayManager: relayManager, + allowedIP: allowedIP, + allowedNet: allowedNet.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } @@ -692,7 +689,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) } } @@ -783,8 +780,13 @@ func (conn *Conn) freeUpConnID() { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.log.Debugf("setup proxied WireGuard connection") - wgProxy := conn.wgProxyFactory.GetProxy() - if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { + udpAddr := &net.UDPAddr{ + IP: conn.allowedIP, + Port: conn.config.WgConfig.WgListenPort, + } + + wgProxy := conn.config.WgConfig.WgInterface.GetProxy() + if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) return nil, err } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index b4926a9d2e..e68861c5f0 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -11,7 +11,6 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/util" ) @@ -44,11 +43,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) if err != nil { return } @@ -59,11 +54,7 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } @@ -96,11 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } @@ -132,11 +119,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go deleted file mode 100644 index 369ba99db1..0000000000 --- a/client/internal/wgproxy/factory_linux.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build !android - -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" -) - -type Factory struct { - wgPort int - ebpfProxy *ebpf.WGEBPFProxy -} - -func NewFactory(userspace bool, wgPort int) *Factory { - f := &Factory{wgPort: wgPort} - - if userspace { - return f - } - - ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) - err := ebpfProxy.Listen() - if err != nil { - log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) - return f - } - - f.ebpfProxy = ebpfProxy - return f -} - -func (w *Factory) GetProxy() Proxy { - if w.ebpfProxy != nil { - p := &ebpf.ProxyWrapper{ - WgeBPFProxy: w.ebpfProxy, - } - return p - } - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - if w.ebpfProxy == nil { - return nil - } - return w.ebpfProxy.Free() -} diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go deleted file mode 100644 index f930b09b3a..0000000000 --- a/client/internal/wgproxy/factory_nonlinux.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !linux || android - -package wgproxy - -import "github.com/netbirdio/netbird/client/internal/wgproxy/usp" - -type Factory struct { - wgPort int -} - -func NewFactory(_ bool, wgPort int) *Factory { - return &Factory{wgPort: wgPort} -} - -func (w *Factory) GetProxy() Proxy { - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - return nil -} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go deleted file mode 100644 index 558121cdd5..0000000000 --- a/client/internal/wgproxy/proxy.go +++ /dev/null @@ -1,15 +0,0 @@ -package wgproxy - -import ( - "context" - "net" -) - -// Proxy is a transfer layer between the relayed connection and the WireGuard -type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) error - EndpointAddr() *net.UDPAddr - Work() - Pause() - CloseConn() error -}