Skip to content

Commit

Permalink
global: switch to netip
Browse files Browse the repository at this point in the history
Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
zx2c4 committed Nov 2, 2021
1 parent 539979e commit 4776166
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 356 deletions.
78 changes: 33 additions & 45 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"net"
"strings"
"time"

"golang.zx2c4.com/go118/netip"

"golang.org/x/crypto/curve25519"

"golang.zx2c4.com/wireguard/windows/l18n"
Expand All @@ -22,8 +24,7 @@ import (
const KeyLength = 32

type IPCidr struct {
IP net.IP
Cidr uint8
netip.Prefix
}

type Endpoint struct {
Expand All @@ -46,7 +47,7 @@ type Interface struct {
Addresses []IPCidr
ListenPort uint16
MTU uint16
DNS []net.IP
DNS []netip.Addr
DNSSearch []string
PreUp string
PostUp string
Expand All @@ -67,62 +68,28 @@ type Peer struct {
LastHandshakeTime HandshakeTime
}

func (r *IPCidr) String() string {
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
}

func (r *IPCidr) Bits() uint8 {
if r.IP.To4() != nil {
return 32
}
return 128
}

func (r *IPCidr) IPNet() net.IPNet {
return net.IPNet{
IP: r.IP,
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
}
}

func (r *IPCidr) MaskSelf() {
bits := int(r.Bits())
mask := net.CIDRMask(int(r.Cidr), bits)
for i := 0; i < bits/8; i++ {
r.IP[i] &= mask[i]
}
}

func (conf *Config) IntersectsWith(other *Config) bool {
type hashableIPCidr struct {
ip string
cidr byte
}
allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
for _, a := range conf.Interface.Addresses {
allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true
a.MaskSelf()
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true
allRoutes[a.Masked()] = true
}
for i := range conf.Peers {
for _, a := range conf.Peers[i].AllowedIPs {
a.MaskSelf()
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
allRoutes[a.Masked()] = true
}
}
for _, a := range other.Interface.Addresses {
if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] {
if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] {
return true
}
a.MaskSelf()
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
if allRoutes[a.Masked()] {
return true
}
}
for i := range other.Peers {
for _, a := range other.Peers[i].AllowedIPs {
a.MaskSelf()
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
if allRoutes[a.Masked()] {
return true
}
}
Expand Down Expand Up @@ -233,6 +200,27 @@ func (b Bytes) String() string {
return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024)
}

func (p IPCidr) MarshalBinary() ([]byte, error) {
b, err := p.Addr().MarshalBinary()
if err != nil {
return nil, err
}
return append(b, uint8(p.Bits())), nil
}

func (p *IPCidr) UnmarshalBinary(b []byte) error {
if len(b) < 1 {
return errors.New("unexpected byte slice")
}
var addr netip.Addr
err := addr.UnmarshalBinary(b[:len(b)-1])
if err != nil {
return err
}
*p = IPCidr{netip.PrefixFrom(addr, int(b[len(b)-1]))}
return nil
}

func (conf *Config) DeduplicateNetworkEntries() {
m := make(map[string]bool, len(conf.Interface.Addresses))
i := 0
Expand Down
17 changes: 9 additions & 8 deletions conf/dnsresolver_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ package conf
import (
"fmt"
"log"
"net"
"syscall"
"time"
"unsafe"

"golang.zx2c4.com/go118/netip"

"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/services"
)
Expand Down Expand Up @@ -66,24 +67,24 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
return
}
defer windows.FreeAddrInfoW(result)
ipv6 := ""
var v6 netip.Addr
for ; result != nil; result = result.Next {
switch result.Family {
case windows.AF_INET:
return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
return netip.AddrFrom4((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr).String(), nil
case windows.AF_INET6:
if len(ipv6) != 0 {
if v6.IsValid() {
continue
}
a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
ipv6 = (net.IP)(a.Addr[:]).String()
v6 = netip.AddrFrom16(a.Addr)
if a.Scope_id != 0 {
ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
v6 = v6.WithZone(fmt.Sprint(a.Scope_id))
}
}
}
if len(ipv6) != 0 {
return ipv6, nil
if v6.IsValid() {
return v6.String(), nil
}
err = windows.WSAHOST_NOT_FOUND
return
Expand Down
73 changes: 22 additions & 51 deletions conf/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ package conf

import (
"encoding/base64"
"net"
"strconv"
"strings"

"golang.zx2c4.com/go118/netip"

"golang.org/x/sys/windows"
"golang.org/x/text/encoding/unicode"

Expand All @@ -27,43 +28,16 @@ func (e *ParseError) Error() string {
return l18n.Sprintf("%s: %q", e.why, e.offender)
}

func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
var addrStr, cidrStr string
var cidr int

i := strings.IndexByte(s, '/')
if i < 0 {
addrStr = s
} else {
addrStr, cidrStr = s[:i], s[i+1:]
}

err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
addr := net.ParseIP(addrStr)
if addr == nil {
return
}
maybeV4 := addr.To4()
if maybeV4 != nil {
addr = maybeV4
func parseIPCidr(s string) (IPCidr, error) {
ipcidr, err := netip.ParsePrefix(s)
if err == nil {
return IPCidr{ipcidr}, nil
}
if len(cidrStr) > 0 {
err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
cidr, err = strconv.Atoi(cidrStr)
if err != nil || cidr < 0 || cidr > 128 {
return
}
if cidr > 32 && maybeV4 != nil {
return
}
} else {
if maybeV4 != nil {
cidr = 32
} else {
cidr = 128
}
addr, err := netip.ParseAddr(s)
if err != nil {
return IPCidr{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
}
return &IPCidr{addr, uint8(cidr)}, nil
return IPCidr{netip.PrefixFrom(addr, addr.BitLen())}, nil
}

func parseEndpoint(s string) (*Endpoint, error) {
Expand All @@ -87,16 +61,16 @@ func parseEndpoint(s string) (*Endpoint, error) {
if i := strings.LastIndexByte(host, '%'); i > 1 {
end = i
}
maybeV6 := net.ParseIP(host[1:end])
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
maybeV6, err2 := netip.ParseAddr(host[1:end])
if err2 != nil || !maybeV6.Is6() {
return nil, err
}
} else {
return nil, err
}
host = host[1 : len(host)-1]
}
return &Endpoint{host, uint16(port)}, nil
return &Endpoint{host, port}, nil
}

func parseMTU(s string) (uint16, error) {
Expand Down Expand Up @@ -256,16 +230,16 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
conf.Interface.Addresses = append(conf.Interface.Addresses, a)
}
case "dns":
addresses, err := splitList(val)
if err != nil {
return nil, err
}
for _, address := range addresses {
a := net.ParseIP(address)
if a == nil {
a, err := netip.ParseAddr(address)
if err != nil {
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
} else {
conf.Interface.DNS = append(conf.Interface.DNS, a)
Expand Down Expand Up @@ -312,7 +286,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
peer.AllowedIPs = append(peer.AllowedIPs, *a)
peer.AllowedIPs = append(peer.AllowedIPs, a)
}
case "persistentkeepalive":
p, err := parsePersistentKeepalive(val)
Expand Down Expand Up @@ -399,7 +373,7 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
}
if p.Flags&driver.PeerHasEndpoint != 0 {
peer.Endpoint.Port = p.Endpoint.Port()
peer.Endpoint.Host = p.Endpoint.IP().String()
peer.Endpoint.Host = p.Endpoint.Addr().String()
}
if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
peer.PersistentKeepalive = p.PersistentKeepalive
Expand All @@ -416,16 +390,13 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
} else {
a = a.NextAllowedIP()
}
var ip net.IP
var ip netip.Addr
if a.AddressFamily == windows.AF_INET {
ip = a.Address[:4]
ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
} else if a.AddressFamily == windows.AF_INET6 {
ip = a.Address[:16]
ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
}
peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{
IP: ip,
Cidr: a.Cidr,
})
peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{netip.PrefixFrom(ip, int(a.Cidr))})
}
conf.Peers = append(conf.Peers, peer)
}
Expand Down
8 changes: 4 additions & 4 deletions conf/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
package conf

import (
"net"
"reflect"
"runtime"
"testing"

"golang.zx2c4.com/go118/netip"
)

const testInput = `
Expand Down Expand Up @@ -77,10 +78,9 @@ func contains(t *testing.T, list, element interface{}) bool {
func TestFromWgQuick(t *testing.T) {
conf, err := FromWgQuick(testInput, "test")
if noError(t, err) {

lenTest(t, conf.Interface.Addresses, 2)
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16))
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24))
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
equal(t, uint16(51820), conf.Interface.ListenPort)

Expand Down
Loading

0 comments on commit 4776166

Please sign in to comment.