diff --git a/net/tcp/listen.go b/net/tcp/listen.go index 0175aec9..45efc19c 100644 --- a/net/tcp/listen.go +++ b/net/tcp/listen.go @@ -76,17 +76,7 @@ func (lf *ListenerFactory) NewListener(v *vrf.VRF, laddr *net.TCPAddr, ttl uint8 } } - if laddr.IP.To4() != nil { - err = unix.Bind(fd, &unix.SockaddrInet4{ - Port: laddr.Port, - Addr: ipv4AddrToArray(laddr.IP), - }) - } else { - err = unix.Bind(fd, &unix.SockaddrInet6{ - Port: laddr.Port, - Addr: ipv6AddrToArray(laddr.IP), - }) - } + err = unix.Bind(fd, netTCPAddrToSockAddr(laddr)) if err != nil { unix.Close(fd) return nil, fmt.Errorf("bind failed: %w", err) diff --git a/net/tcp/tcpsock_posix.go b/net/tcp/tcpsock_posix.go index cb0ea87b..16c84b8d 100644 --- a/net/tcp/tcpsock_posix.go +++ b/net/tcp/tcpsock_posix.go @@ -47,22 +47,7 @@ func dialTCP(afi uint16, laddr, raddr *net.TCPAddr, ttl uint8, md5secret string, } if laddr != nil && laddr.IP != nil { - var bindSA unix.Sockaddr - if laddr.IP.To4() != nil { - la := ipv4AddrToArray(laddr.IP) - bindSA = &unix.SockaddrInet4{ - Port: laddr.Port, - Addr: la, - } - } else { - la := ipv6AddrToArray(laddr.IP) - bindSA = &unix.SockaddrInet6{ - Port: laddr.Port, - Addr: la, - } - } - - err := unix.Bind(fd, bindSA) + err := unix.Bind(fd, netTCPAddrToSockAddr(laddr)) if err != nil { return nil, fmt.Errorf("bind() failed: %w", err) } @@ -78,20 +63,7 @@ func dialTCP(afi uint16, laddr, raddr *net.TCPAddr, ttl uint8, md5secret string, } } - var connectSA unix.Sockaddr - if raddr.IP.To4() != nil { - connectSA = &unix.SockaddrInet4{ - Port: raddr.Port, - Addr: ipv4AddrToArray(raddr.IP), - } - } else { - connectSA = &unix.SockaddrInet6{ - Port: raddr.Port, - Addr: ipv6AddrToArray(raddr.IP), - } - } - - err = unix.Connect(fd, connectSA) + err = unix.Connect(fd, netTCPAddrToSockAddr(raddr)) if err != nil { return nil, fmt.Errorf("connect() failed: %w", err) } @@ -103,15 +75,23 @@ func dialTCP(afi uint16, laddr, raddr *net.TCPAddr, ttl uint8, md5secret string, }, nil } -func ipv6AddrToArray(x net.IP) [16]byte { - return [16]byte{ - x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], - x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15], +func netTCPAddrToSockAddr(tcpAddr *net.TCPAddr) unix.Sockaddr { + ip := tcpAddr.IP + ip4 := ip.To4() + if ip4 != nil { + return &unix.SockaddrInet4{ + Port: tcpAddr.Port, + Addr: [4]byte{ + ip4[0], ip4[1], ip4[2], ip4[3], + }, + } } -} -func ipv4AddrToArray(x net.IP) [4]byte { - return [4]byte{ - x[0], x[1], x[2], x[3], + return &unix.SockaddrInet6{ + Port: tcpAddr.Port, + Addr: [16]byte{ + ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], + ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15], + }, } } diff --git a/net/tcp/tcpsock_posix_test.go b/net/tcp/tcpsock_posix_test.go new file mode 100644 index 00000000..7f27668e --- /dev/null +++ b/net/tcp/tcpsock_posix_test.go @@ -0,0 +1,50 @@ +package tcp + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" +) + +func TestNetTCPAddrToSockAddr(t *testing.T) { + tests := []struct { + name string + addr string + expected unix.Sockaddr + }{ + { + name: "IPv4 / 0.0.0.0", + addr: "0.0.0.0:179", + expected: &unix.SockaddrInet4{ + Port: 179, + Addr: [4]byte{0, 0, 0, 0}, + }, + }, + { + name: "IPv4 / 192.0.2.42", + addr: "192.0.2.42:179", + expected: &unix.SockaddrInet4{ + Port: 179, + Addr: [4]byte{192, 0, 2, 42}, + }, + }, + { + name: "IPv6 / 2001:db8::$2", + addr: "[2001:db8::42]:179", + expected: &unix.SockaddrInet6{ + Port: 179, + Addr: [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x42}, + }, + }, + } + + for _, test := range tests { + tcpaddr, err := net.ResolveTCPAddr("tcp", test.addr) + if err != nil { + t.Fatalf("Failed to resolve TCPAddr: %v", err) + } + assert.Equal(t, test.expected, netTCPAddrToSockAddr(tcpaddr), test.name) + } +}