diff --git a/capsule.go b/capsule.go index e2b860c..716cd8c 100644 --- a/capsule.go +++ b/capsule.go @@ -164,25 +164,27 @@ func parseAddress(r io.Reader) (requestID uint64, prefix netip.Prefix, _ error) // routeAdvertisementCapsule represents a ROUTE_ADVERTISEMENT capsule type routeAdvertisementCapsule struct { - IPAddressRanges []IPAddressRange + IPAddressRanges []IPRoute } -// IPAddressRange represents an IP Address Range within a ROUTE_ADVERTISEMENT capsule -type IPAddressRange struct { - StartIP netip.Addr - EndIP netip.Addr +// IPRoute represents an IP Address Range +type IPRoute struct { + StartIP netip.Addr + EndIP netip.Addr + // IPProtocol is the Internet Protocol Number for traffic that can be sent to this range. + // If the value is 0, all protocols are allowed. IPProtocol uint8 } -func (r IPAddressRange) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 } +func (r IPRoute) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 } // Prefixes returns the prefixes that this IP address range covers. // Note that depending on the start and end addresses, // this conversion can result in a large number of prefixes. -func (r IPAddressRange) Prefixes() []netip.Prefix { return rangeToPrefixes(r.StartIP, r.EndIP) } +func (r IPRoute) Prefixes() []netip.Prefix { return rangeToPrefixes(r.StartIP, r.EndIP) } func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, error) { - var ranges []IPAddressRange + var ranges []IPRoute for { ipRange, err := parseIPAddressRange(r) if err != nil { @@ -218,10 +220,10 @@ func (c *routeAdvertisementCapsule) append(b []byte) []byte { return b } -func parseIPAddressRange(r io.Reader) (IPAddressRange, error) { +func parseIPAddressRange(r io.Reader) (IPRoute, error) { var ipVersion uint8 if err := binary.Read(r, binary.LittleEndian, &ipVersion); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } var startIP, endIP netip.Addr @@ -229,36 +231,36 @@ func parseIPAddressRange(r io.Reader) (IPAddressRange, error) { case 4: var start, end [4]byte if _, err := io.ReadFull(r, start[:]); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } if _, err := io.ReadFull(r, end[:]); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } startIP = netip.AddrFrom4(start) endIP = netip.AddrFrom4(end) case 6: var start, end [16]byte if _, err := io.ReadFull(r, start[:]); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } if _, err := io.ReadFull(r, end[:]); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } startIP = netip.AddrFrom16(start) endIP = netip.AddrFrom16(end) default: - return IPAddressRange{}, fmt.Errorf("invalid IP version: %d", ipVersion) + return IPRoute{}, fmt.Errorf("invalid IP version: %d", ipVersion) } if startIP.Compare(endIP) > 0 { - return IPAddressRange{}, errors.New("start IP is greater than end IP") + return IPRoute{}, errors.New("start IP is greater than end IP") } var ipProtocol uint8 if err := binary.Read(r, binary.LittleEndian, &ipProtocol); err != nil { - return IPAddressRange{}, err + return IPRoute{}, err } - return IPAddressRange{ + return IPRoute{ StartIP: startIP, EndIP: endIP, IPProtocol: ipProtocol, diff --git a/capsule_test.go b/capsule_test.go index c530d54..10f33aa 100644 --- a/capsule_test.go +++ b/capsule_test.go @@ -221,7 +221,7 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) { capsule, err := parseRouteAdvertisementCapsule(cr) require.NoError(t, err) require.Equal(t, - []IPAddressRange{ + []IPRoute{ {StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13}, {StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37}, }, @@ -240,7 +240,7 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) { func TestWriteRouteAdvertisementCapsule(t *testing.T) { c := &routeAdvertisementCapsule{ - IPAddressRanges: []IPAddressRange{ + IPAddressRanges: []IPRoute{ {StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13}, {StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37}, }, diff --git a/cert_test.go b/cert_test.go new file mode 100644 index 0000000..3916d47 --- /dev/null +++ b/cert_test.go @@ -0,0 +1,89 @@ +package connectip + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "log" + "math/big" + "time" + + "github.com/quic-go/quic-go/http3" +) + +var ( + tlsConf *tls.Config + certPool *x509.CertPool +) + +func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, caPrivateKey, nil +} + +func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"localhost", "127.0.0.1"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, err + } + return cert, privKey, nil +} + +func init() { + ca, caPrivateKey, err := generateCA() + if err != nil { + log.Fatal("failed to generate CA certificate:", err) + } + leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) + if err != nil { + log.Fatal("failed to generate leaf certificate:", err) + } + certPool = x509.NewCertPool() + certPool.AddCert(ca) + tlsConf = &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{http3.NextProtoH3}, + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..248c8df --- /dev/null +++ b/client.go @@ -0,0 +1,62 @@ +package connectip + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/quic-go/quic-go/http3" + "github.com/yosida95/uritemplate/v3" +) + +// Dial dials a proxied connection to a target server. +func Dial(ctx context.Context, conn *http3.ClientConn, template *uritemplate.Template) (*Conn, *http.Response, error) { + if len(template.Varnames()) > 0 { + return nil, nil, errors.New("connect-ip-go currently does not support IP flow forwarding") + } + + u, err := url.Parse(template.Raw()) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err) + } + + select { + case <-ctx.Done(): + return nil, nil, context.Cause(ctx) + case <-conn.Context().Done(): + return nil, nil, context.Cause(conn.Context()) + case <-conn.ReceivedSettings(): + } + settings := conn.Settings() + if !settings.EnableExtendedConnect { + return nil, nil, errors.New("connect-ip: server didn't enable Extended CONNECT") + } + if !settings.EnableDatagrams { + return nil, nil, errors.New("connect-ip: server didn't enable Datagrams") + } + + rstr, err := conn.OpenRequestStream(ctx) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to open request stream: %w", err) + } + if err := rstr.SendRequestHeader(&http.Request{ + Method: http.MethodConnect, + Proto: requestProtocol, + Host: u.Host, + Header: http.Header{capsuleHeader: []string{capsuleProtocolHeaderValue}}, + URL: u, + }); err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err) + } + // TODO: optimistically return the connection + rsp, err := rstr.ReadResponse() + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to read response: %w", err) + } + if rsp.StatusCode < 200 || rsp.StatusCode > 299 { + return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode) + } + return newProxiedConn(rstr), rsp, nil +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..db811b5 --- /dev/null +++ b/conn.go @@ -0,0 +1,330 @@ +package connectip + +import ( + "context" + "errors" + "fmt" + "log" + "net/netip" + "slices" + "sync" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + + "github.com/quic-go/quic-go/http3" + "github.com/quic-go/quic-go/quicvarint" +) + +type appendable interface{ append([]byte) []byte } + +type writeCapsule struct { + capsule appendable + result chan error +} + +// Conn is a connection that proxies IP packets over HTTP/3. +type Conn struct { + str http3.Stream + writes chan writeCapsule + + assignedAddressNotify chan struct{} + availableRoutesNotify chan struct{} + + mu sync.Mutex + peerAddresses []netip.Prefix // IP prefixes that we assigned to the peer + localRoutes []IPRoute // IP routes that we advertised to the peer + assignedAddresses []netip.Prefix + availableRoutes []IPRoute +} + +func newProxiedConn(str http3.Stream) *Conn { + c := &Conn{ + str: str, + writes: make(chan writeCapsule), + assignedAddressNotify: make(chan struct{}, 1), + availableRoutesNotify: make(chan struct{}, 1), + } + go func() { + if err := c.readFromStream(); err != nil { + log.Printf("handling stream failed: %v", err) + } + }() + go func() { + if err := c.writeToStream(); err != nil { + log.Printf("writing to stream failed: %v", err) + } + }() + return c +} + +// AdvertiseRoute informs the peer about available routes. +// This function can be called multiple times, but only the routes from the most recent call will be active. +// Previous route advertisements are overwritten by each new call to this function. +func (c *Conn) AdvertiseRoute(ctx context.Context, routes []IPRoute) error { + for _, route := range routes { + if route.StartIP.Compare(route.EndIP) == 1 { + return fmt.Errorf("invalid route advertising start_ip: %s larger than %s", route.StartIP, route.EndIP) + } + } + c.mu.Lock() + c.localRoutes = slices.Clone(routes) + c.mu.Unlock() + return c.sendCapsule(ctx, &routeAdvertisementCapsule{IPAddressRanges: routes}) +} + +// AssignAddresses assigned address prefixes to the peer. +// This function can be called multiple times, but only the addresses from the most recent call will be active. +// Previous address assignments are overwritten by each new call to this function. +func (c *Conn) AssignAddresses(ctx context.Context, prefixes []netip.Prefix) error { + c.mu.Lock() + c.peerAddresses = slices.Clone(prefixes) + c.mu.Unlock() + capsule := &addressAssignCapsule{AssignedAddresses: make([]AssignedAddress, 0, len(prefixes))} + for _, p := range prefixes { + capsule.AssignedAddresses = append(capsule.AssignedAddresses, AssignedAddress{IPPrefix: p}) + } + return c.sendCapsule(ctx, capsule) +} + +func (c *Conn) sendCapsule(ctx context.Context, capsule appendable) error { + res := make(chan error, 1) + select { + case c.writes <- writeCapsule{ + capsule: capsule, + result: res, + }: + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-res: + return err + } + case <-ctx.Done(): + return ctx.Err() + } +} + +// LocalPrefixes returns the prefixes that the peer currently assigned. +// Note that at any point during the connection, the peer can change the assignment. +// It is therefore recommended to call this function in a loop. +func (c *Conn) LocalPrefixes(ctx context.Context) ([]netip.Prefix, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.assignedAddressNotify: + c.mu.Lock() + defer c.mu.Unlock() + return c.assignedAddresses, nil + } +} + +// Routes returns the routes that the peer currently advertised. +// Note that at any point during the connection, the peer can change the advertised routes. +// It is therefore recommended to call this function in a loop. +func (c *Conn) Routes(ctx context.Context) ([]IPRoute, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.availableRoutesNotify: + c.mu.Lock() + defer c.mu.Unlock() + return c.availableRoutes, nil + } +} + +func (c *Conn) readFromStream() error { + defer c.str.Close() + r := quicvarint.NewReader(c.str) + for { + t, cr, err := http3.ParseCapsule(r) + if err != nil { + return err + } + switch t { + case capsuleTypeAddressAssign: + capsule, err := parseAddressAssignCapsule(cr) + if err != nil { + return err + } + prefixes := make([]netip.Prefix, 0, len(capsule.AssignedAddresses)) + for _, assigned := range capsule.AssignedAddresses { + prefixes = append(prefixes, assigned.IPPrefix) + } + c.mu.Lock() + c.assignedAddresses = prefixes + c.mu.Unlock() + select { + case c.assignedAddressNotify <- struct{}{}: + default: + } + case capsuleTypeAddressRequest: + if _, err := parseAddressRequestCapsule(cr); err != nil { + return err + } + return errors.New("connect-ip: address request not yet supported") + case capsuleTypeRouteAdvertisement: + capsule, err := parseRouteAdvertisementCapsule(cr) + if err != nil { + return err + } + c.mu.Lock() + c.availableRoutes = capsule.IPAddressRanges + c.mu.Unlock() + select { + case c.availableRoutesNotify <- struct{}{}: + default: + } + default: + return fmt.Errorf("unknown capsule type: %d", t) + } + } +} + +func (c *Conn) writeToStream() error { + buf := make([]byte, 0, 1024) + for { + req, ok := <-c.writes + if !ok { + return nil + } + buf = req.capsule.append(buf[:0]) + _, err := c.str.Write(buf) + req.result <- err + if err != nil { + return err + } + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { +start: + data, err := c.str.ReceiveDatagram(context.Background()) + if err != nil { + return 0, err + } + contextID, n, err := quicvarint.Parse(data) + if err != nil { + // TODO: close connection + return 0, fmt.Errorf("connect-ip: malformed datagram: %w", err) + } + if contextID != 0 { + // Drop this datagram. We currently only support proxying of IP payloads. + goto start + } + if err := c.handleIncomingProxiedPacket(data[n:]); err != nil { + log.Printf("dropping proxied packet: %s", err) + goto start + } + return copy(b, data[n:]), nil +} + +func (c *Conn) handleIncomingProxiedPacket(data []byte) error { + if len(data) == 0 { + return errors.New("connect-ip: empty packet") + } + var src, dst netip.Addr + var ipProto uint8 + switch v := ipVersion(data); v { + default: + return fmt.Errorf("connect-ip: unknown IP versions: %d", v) + case 4: + if len(data) < ipv4.HeaderLen { + return fmt.Errorf("connect-ip: malformed datagram: too short") + } + src = netip.AddrFrom4([4]byte(data[12:16])) + dst = netip.AddrFrom4([4]byte(data[16:20])) + ipProto = data[9] + case 6: + if len(data) < ipv6.HeaderLen { + return fmt.Errorf("connect-ip: malformed datagram: too short") + } + src = netip.AddrFrom16([16]byte(data[8:24])) + dst = netip.AddrFrom16([16]byte(data[24:40])) + ipProto = data[6] + } + + c.mu.Lock() + assignedAddresses := c.assignedAddresses + localRoutes := c.localRoutes + peerAddresses := c.peerAddresses + c.mu.Unlock() + + // We don't necessarily assign any addresses to the peer. + // For example, in the Remote Access VPN use case (RFC 9484, section 8.1), + // the client accepts incoming traffic from all IPs. + if peerAddresses != nil { + if !slices.ContainsFunc(peerAddresses, func(p netip.Prefix) bool { return p.Contains(src) }) { + // TODO: send ICMP + return fmt.Errorf("connect-ip: datagram source address not allowed: %s", src) + } + } + + // The destination IP address is valid if it + // 1. is within one of the ranges assigned to us, or + // 2. is within one of the ranges that we advertised to the peer. + var isAllowedDst bool + if len(assignedAddresses) > 0 { + isAllowedDst = slices.ContainsFunc(assignedAddresses, func(p netip.Prefix) bool { return p.Contains(dst) }) + } + if !isAllowedDst { + isAllowedDst = slices.ContainsFunc(localRoutes, func(r IPRoute) bool { + if r.StartIP.Compare(dst) > 0 || dst.Compare(r.EndIP) > 0 { + return false + } + // TODO: walk the chain of IPv6 extensions + // See section 4.8 of RFC 9484 for details. + return ipProto == 0 || r.IPProtocol == 0 || r.IPProtocol == ipProto + }) + } + if !isAllowedDst { + // TODO: send ICMP + return fmt.Errorf("connect-ip: datagram destination address / protocol not allowed: %s (protocol: %d)", dst, ipProto) + } + return nil +} + +func (c *Conn) Write(b []byte) (n int, err error) { + data, err := c.composeDatagram(b) + if err != nil { + log.Printf("dropping proxied packet (%d bytes) that can't be proxied: %s", len(b), err) + return 0, nil + } + return len(b), c.str.SendDatagram(data) +} + +func (c *Conn) composeDatagram(b []byte) ([]byte, error) { + // TODO: implement src, dst and ipproto checks + if len(b) == 0 { + return nil, nil + } + switch v := ipVersion(b); v { + default: + return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v) + case 4: + if len(b) < ipv4.HeaderLen { + return nil, fmt.Errorf("connect-ip: IPv4 packet too short") + } + ttl := b[8] + if ttl <= 1 { + return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl) + } + b[8]-- // Decrement TTL + // TODO: maybe recalculate the checksum? + case 6: + if len(b) < ipv6.HeaderLen { + return nil, fmt.Errorf("connect-ip: IPv6 packet too short") + } + hopLimit := b[7] + if hopLimit <= 1 { + return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit) + } + b[7]-- // Decrement Hop Limit + } + data := make([]byte, 0, len(contextIDZero)+len(b)) + data = append(data, contextIDZero...) + data = append(data, b...) + return data, nil +} + +func ipVersion(b []byte) uint8 { return b[0] >> 4 } diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..cece306 --- /dev/null +++ b/conn_test.go @@ -0,0 +1,224 @@ +package connectip + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + + "github.com/stretchr/testify/require" +) + +var ipv6Header = []byte{ + 0x60, 0x00, 0x00, 0x00, // Version, Traffic Class, Flow Label + 0x00, 0x20, 59, 64, // Payload Length, Next Header, Hop Limit + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // Source IP + 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x08, 0xd3, 0x13, 0x19, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x48, // Destination IP +} + +type mockStream struct { + reading []byte + toRead <-chan []byte +} + +var _ http3.Stream = &mockStream{} + +func (m *mockStream) StreamID() quic.StreamID { panic("implement me") } +func (m *mockStream) Read(p []byte) (int, error) { + if m.reading == nil { + m.reading = <-m.toRead + } + n := copy(p, m.reading) + m.reading = m.reading[n:] + return n, nil +} +func (m *mockStream) CancelRead(quic.StreamErrorCode) {} +func (m *mockStream) Write(p []byte) (n int, err error) { return len(p), nil } +func (m *mockStream) Close() error { return nil } +func (m *mockStream) CancelWrite(quic.StreamErrorCode) {} +func (m *mockStream) Context() context.Context { return context.Background() } +func (m *mockStream) SetWriteDeadline(time.Time) error { return nil } +func (m *mockStream) SetReadDeadline(time.Time) error { return nil } +func (m *mockStream) SetDeadline(time.Time) error { return nil } +func (m *mockStream) SendDatagram(data []byte) error { return nil } +func (m *mockStream) ReceiveDatagram(ctx context.Context) ([]byte, error) { + <-ctx.Done() + return nil, ctx.Err() +} + +func TestDatagramParsing(t *testing.T) { + t.Run("empty packets", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket([]byte{}), + "connect-ip: empty packet", + ) + }) + + t.Run("invalid IP version", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + data := make([]byte, 20) + data[0] = 5 << 4 // IPv5 + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(data), + "connect-ip: unknown IP versions: 5", + ) + }) + + t.Run("IPv4 packet too short", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + data, err := (&ipv4.Header{ + Src: net.IPv4(1, 2, 3, 4), + Dst: net.IPv4(159, 70, 42, 98), + Len: 20, + Checksum: 89, + }).Marshal() + require.NoError(t, err) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(data[:ipv4.HeaderLen-1]), + "connect-ip: malformed datagram: too short", + ) + }) + + t.Run("IPv6 packet too short", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(ipv6Header[:ipv6.HeaderLen-1]), + "connect-ip: malformed datagram: too short", + ) + }) + + t.Run("invalid source address", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, conn.AssignAddresses(ctx, []netip.Prefix{netip.MustParsePrefix("192.168.0.10/32")})) + hdr := &ipv4.Header{ + Src: net.IPv4(192, 168, 0, 11), + Dst: net.IPv4(159, 70, 42, 98), + Len: 20, + Checksum: 89, + } + data, err := hdr.Marshal() + require.NoError(t, err) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(data), + "connect-ip: datagram source address not allowed: 192.168.0.11", + ) + }) + + t.Run("invalid destination address", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, conn.AssignAddresses(ctx, []netip.Prefix{netip.MustParsePrefix("192.168.0.10/32")})) + require.NoError(t, conn.AdvertiseRoute(ctx, []IPRoute{ + {StartIP: netip.MustParseAddr("10.0.0.0"), EndIP: netip.MustParseAddr("10.1.2.3")}, + })) + hdr := &ipv4.Header{ + Src: net.IPv4(192, 168, 0, 10), + Dst: net.IPv4(10, 1, 2, 3), + Len: 20, + Checksum: 89, + } + data, err := hdr.Marshal() + require.NoError(t, err) + require.NoError(t, conn.handleIncomingProxiedPacket(data)) + + // 10.1.2.4 is outside the range of allowed addresses + hdr.Dst = net.IPv4(10, 1, 2, 4) + data, err = hdr.Marshal() + require.NoError(t, err) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(data), + "connect-ip: datagram destination address / protocol not allowed: 10.1.2.4 (protocol: 0)", + ) + }) + + t.Run("invalid IP protocol", func(t *testing.T) { + conn := newProxiedConn(&mockStream{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, conn.AssignAddresses(ctx, []netip.Prefix{netip.MustParsePrefix("192.168.0.10/32")})) + require.NoError(t, conn.AdvertiseRoute(ctx, []IPRoute{ + {StartIP: netip.MustParseAddr("10.0.0.0"), EndIP: netip.MustParseAddr("10.1.2.3"), IPProtocol: 42}, + })) + hdr := &ipv4.Header{ + Src: net.IPv4(192, 168, 0, 10), + Dst: net.IPv4(10, 1, 2, 3), + Len: 20, + Checksum: 89, + Protocol: 42, + } + data, err := hdr.Marshal() + require.NoError(t, err) + require.NoError(t, conn.handleIncomingProxiedPacket(data)) + + hdr.Protocol = 41 + data, err = hdr.Marshal() + require.NoError(t, err) + require.ErrorContains(t, + conn.handleIncomingProxiedPacket(data), + "connect-ip: datagram destination address / protocol not allowed: 10.1.2.3 (protocol: 41)", + ) + + // ICMP is always allowed + hdr.Protocol = 0 + data, err = hdr.Marshal() + require.NoError(t, err) + require.NoError(t, conn.handleIncomingProxiedPacket(data)) + }) + + t.Run("packet from assigned address", func(t *testing.T) { + readChan := make(chan []byte, 1) + conn := newProxiedConn(&mockStream{toRead: readChan}) + + hdr := &ipv4.Header{ + Src: net.IPv4(159, 70, 42, 98), + Dst: net.IPv4(192, 168, 0, 10), + Len: 20, + Checksum: 89, + } + data, err := hdr.Marshal() + require.NoError(t, err) + require.Error(t, conn.handleIncomingProxiedPacket(data), "connect-ip: datagram destination address") + + // now assign 192.168.0.11 to this connection + readChan <- (&addressAssignCapsule{ + AssignedAddresses: []AssignedAddress{{IPPrefix: netip.MustParsePrefix("192.168.0.10/32")}}, + }).append(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.LocalPrefixes(ctx) + require.NoError(t, err) + // after processing the address assignment, this is a valid packet + require.NoError(t, conn.handleIncomingProxiedPacket(data)) + }) +} + +func FuzzIncomingDatagram(f *testing.F) { + conn := newProxiedConn(&mockStream{}) + + ipv4Header, err := (&ipv4.Header{ + Src: net.IPv4(1, 2, 3, 4), + Dst: net.IPv4(159, 70, 42, 98), + Len: 20, + Checksum: 89, + }).Marshal() + require.NoError(f, err) + + f.Add(ipv4Header) + f.Add(ipv6Header) + + f.Fuzz(func(t *testing.T, data []byte) { + conn.handleIncomingProxiedPacket(data) + }) +} diff --git a/go.mod b/go.mod index 4d3c44f..4656328 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/quic-go/quic-go v0.48.0 github.com/stretchr/testify v1.9.0 github.com/yosida95/uritemplate/v3 v3.0.2 + golang.org/x/net v0.28.0 ) require ( @@ -21,7 +22,6 @@ require ( golang.org/x/crypto v0.26.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.28.0 // indirect golang.org/x/sys v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..2da4c18 --- /dev/null +++ b/proxy.go @@ -0,0 +1,20 @@ +package connectip + +import ( + "net/http" + + "github.com/quic-go/quic-go/http3" + "github.com/quic-go/quic-go/quicvarint" +) + +var contextIDZero = quicvarint.Append([]byte{}, 0) + +type Proxy struct{} + +func (s *Proxy) Proxy(w http.ResponseWriter, _ *Request) (*Conn, error) { + w.Header().Set(capsuleHeader, capsuleProtocolHeaderValue) + w.WriteHeader(http.StatusOK) + + str := w.(http3.HTTPStreamer).HTTPStream() + return newProxiedConn(str), nil +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..46622e7 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,138 @@ +package connectip + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/netip" + "testing" + "time" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + "github.com/yosida95/uritemplate/v3" + + "github.com/stretchr/testify/require" +) + +func setupConns(t *testing.T) (client, server *Conn) { + t.Helper() + + p := &Proxy{} + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + template := uritemplate.MustNew(fmt.Sprintf("https://localhost:%d/connect-ip", conn.LocalAddr().(*net.UDPAddr).Port)) + connChan := make(chan *Conn, 1) + mux := http.NewServeMux() + mux.HandleFunc("/connect-ip", func(w http.ResponseWriter, r *http.Request) { + mreq, err := ParseRequest(r, template) + require.NoError(t, err) + + conn, err := p.Proxy(w, mreq) + require.NoError(t, err) + connChan <- conn + }) + s := http3.Server{ + Handler: mux, + Addr: ":0", + EnableDatagrams: true, + TLSConfig: tlsConf, + } + go func() { s.Serve(conn) }() + t.Cleanup(func() { s.Close() }) + + udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { udpConn.Close() }) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + cconn, err := quic.Dial( + ctx, + udpConn, + conn.LocalAddr(), + &tls.Config{ServerName: "localhost", RootCAs: certPool, NextProtos: []string{http3.NextProtoH3}}, + &quic.Config{EnableDatagrams: true}, + ) + require.NoError(t, err) + tr := &http3.Transport{EnableDatagrams: true} + t.Cleanup(func() { tr.Close() }) + + client, rsp, err := Dial(ctx, tr.NewClientConn(cconn), template) + require.NoError(t, err) + require.Equal(t, rsp.StatusCode, http.StatusOK) + + select { + case <-time.After(time.Second): + t.Fatal("timed out") + case conn := <-connChan: + return client, conn + } + return client, server +} + +func TestAddressAssignment(t *testing.T) { + client, server := setupConns(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := server.Routes(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + pref1 := netip.MustParsePrefix("1.1.1.0/24") + pref2 := netip.MustParsePrefix("2001:db8::/64") + require.NoError(t, client.AssignAddresses(ctx, []netip.Prefix{pref1, pref2})) + routes, err := server.LocalPrefixes(ctx) + require.NoError(t, err) + require.Equal(t, []netip.Prefix{pref1, pref2}, routes) + + // addresses are replaced once a new capsule is received + require.NoError(t, client.AssignAddresses(ctx, []netip.Prefix{})) + routes, err = server.LocalPrefixes(ctx) + require.NoError(t, err) + require.Empty(t, routes) +} + +func TestRouteAdvertisement(t *testing.T) { + client, server := setupConns(t) + + // no routes available + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := server.Routes(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // refuse to advertise invalid routes + require.ErrorContains(t, + client.AdvertiseRoute(ctx, []IPRoute{ + {StartIP: netip.MustParseAddr("1.1.1.2"), EndIP: netip.MustParseAddr("1.1.1.1"), IPProtocol: 42}, + }), + "invalid route advertising start_ip: 1.1.1.2 larger than 1.1.1.1", + ) + + // advertise some routes and make sure they're received + require.NoError(t, client.AdvertiseRoute(ctx, []IPRoute{ + {StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("2.2.2.2"), IPProtocol: 42}, + {StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 24}, + })) + routes, err := server.Routes(ctx) + require.NoError(t, err) + require.Equal(t, []IPRoute{ + {StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("2.2.2.2"), IPProtocol: 42}, + {StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 24}, + }, routes) + + // routes are replaced once a new capsule is received + require.NoError(t, client.AdvertiseRoute(ctx, []IPRoute{})) + routes, err = server.Routes(ctx) + require.NoError(t, err) + require.Empty(t, routes) +}