diff --git a/api/tunnel.go b/api/tunnel.go index 96b8560..367301f 100644 --- a/api/tunnel.go +++ b/api/tunnel.go @@ -162,7 +162,6 @@ func tunnelAcceptLoop(ctx context.Context, id string, li net.Listener, tun Tunne if err != nil { log.Printf("error serving local connection %s: %v\n", id, err) } - cEvt.OnDisconnected(ctx, err) }(c) } } diff --git a/go.mod b/go.mod index aa07687..45f1e59 100644 --- a/go.mod +++ b/go.mod @@ -16,11 +16,13 @@ require ( github.com/google/uuid v1.6.0 github.com/martinlindhe/base36 v1.1.1 github.com/pomerium/pomerium v0.28.0 + github.com/quic-go/quic-go v0.48.2 github.com/rs/zerolog v1.33.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.29.0 + golang.org/x/net v0.30.0 golang.org/x/sync v0.9.0 golang.org/x/sys v0.27.0 google.golang.org/grpc v1.68.0 @@ -37,6 +39,8 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/envoyproxy/go-control-plane v0.13.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-set/v3 v3.0.0 // indirect @@ -49,15 +53,18 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mholt/acmez/v2 v2.0.3 // indirect github.com/miekg/dns v1.1.62 // indirect + github.com/onsi/ginkgo/v2 v2.19.1 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/zeebo/blake3 v0.2.4 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/mod v0.20.0 // indirect - golang.org/x/net v0.30.0 // indirect golang.org/x/text v0.20.0 // indirect golang.org/x/tools v0.24.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect diff --git a/go.sum b/go.sum index 775a9fd..fbd39d9 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -111,6 +113,8 @@ github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl76 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 h1:k7nVchz72niMH6YLQNvHSdIE7iqsQxK1P41mySCvssg= +github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= @@ -192,6 +196,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/ginkgo/v2 v2.19.1 h1:QXgq3Z8Crl5EL1WBAC98A5sEBHARrAJNzAmMxzLcRF0= +github.com/onsi/ginkgo/v2 v2.19.1/go.mod h1:O3DtEWQkPa/F7fBMgmZQKKsluAy8pd3rEQdrjkPb9zA= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/open-policy-agent/opa v0.70.0 h1:B3cqCN2iQAyKxK6+GI+N40uqkin+wzIrM7YA60t9x1U= github.com/open-policy-agent/opa v0.70.0/go.mod h1:Y/nm5NY0BX0BqjBriKUiV81sCl8XOjjvqQG7dXrggtI= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -231,6 +239,10 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0= github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= @@ -368,6 +380,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..1de0826 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,23 @@ +package testutil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +// GetPort gets a free port. +func GetPort(t *testing.T) string { + t.Helper() + + li, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + _, port, err := net.SplitHostPort(li.Addr().String()) + require.NoError(t, err) + + _ = li.Close() + + return port +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 9f47a96..ef60c7e 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -2,16 +2,14 @@ package tunnel import ( - "bufio" "context" - "crypto/tls" "errors" "fmt" "io" "log" "net" - "net/http" "net/url" + "sync" "time" backoff "github.com/cenkalti/backoff/v4" @@ -20,10 +18,19 @@ import ( "github.com/pomerium/cli/jwt" ) +var ( + errUnavailable = errors.New("unavailable") + errUnauthenticated = errors.New("unauthenticated") + errUnsupported = errors.New("unsupported") +) + // A Tunnel represents a TCP tunnel over HTTP Connect. type Tunnel struct { cfg *config auth *authclient.AuthClient + + mu sync.Mutex + tcpTunneler TCPTunneler } // New creates a new Tunnel. @@ -106,121 +113,46 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, eventSink Event } func (tun *Tunnel) run(ctx context.Context, eventSink EventSink, local io.ReadWriter, rawJWT string, retryCount int) error { - eventSink.OnConnecting(ctx) - - hdr := http.Header{} - if rawJWT != "" { - hdr.Set("Authorization", "Pomerium "+rawJWT) - } - - req := (&http.Request{ - Method: "CONNECT", - URL: &url.URL{Opaque: tun.cfg.dstHost}, - Host: tun.cfg.dstHost, - Header: hdr, - }).WithContext(ctx) - - var remote net.Conn - var err error - if tun.cfg.tlsConfig != nil { - remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost) - } else { - remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost) - } - if err != nil { - return fmt.Errorf("failed to establish connection to proxy: %w", err) - } - defer func() { - _ = remote.Close() - log.Println("connection closed") - }() - if done := ctx.Done(); done != nil { - go func() { - <-done - _ = remote.Close() - }() - } - - err = req.Write(remote) - if err != nil { - return err + tun.mu.Lock() + if tun.tcpTunneler == nil { + tun.tcpTunneler = tun.pickTCPTunneler(ctx) } + tun.mu.Unlock() - br := bufio.NewReader(remote) - res, err := http.ReadResponse(br, req) - if err != nil { - return fmt.Errorf("failed to read HTTP response: %w", err) - } - defer func() { - _ = res.Body.Close() - }() - switch res.StatusCode { - case http.StatusOK: - case http.StatusServiceUnavailable: + err := tun.tcpTunneler.TunnelTCP(ctx, eventSink, local, rawJWT) + if errors.Is(err, errUnavailable) { // don't delete the JWT if we get a service unavailable - return fmt.Errorf("invalid http response code: %s", res.Status) - case http.StatusMovedPermanently, - http.StatusFound, - http.StatusTemporaryRedirect, - http.StatusPermanentRedirect: - if retryCount == 0 { - _ = remote.Close() - - serverURL := &url.URL{ - Scheme: "http", - Host: tun.cfg.proxyHost, - } - if tun.cfg.tlsConfig != nil { - serverURL.Scheme = "https" - } - - rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { eventSink.OnAuthRequired(ctx, authURL) }) - if err != nil { - return fmt.Errorf("failed to get authentication JWT: %w", err) - } - - err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT) - if err != nil { - return fmt.Errorf("failed to store JWT: %w", err) - } - - return tun.run(ctx, eventSink, local, rawJWT, retryCount+1) + return err + } else if errors.Is(err, errUnauthenticated) && retryCount == 0 { + serverURL := &url.URL{ + Scheme: "http", + Host: tun.cfg.proxyHost, + } + if tun.cfg.tlsConfig != nil { + serverURL.Scheme = "https" } - fallthrough - default: - _ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey()) - return fmt.Errorf("invalid http response code: %d", res.StatusCode) - } - log.Println("connection established") - eventSink.OnConnected(ctx) + rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { + eventSink.OnAuthRequired(ctx, authURL) + }) + if err != nil { + return fmt.Errorf("failed to get authentication JWT: %w", err) + } - errc := make(chan error, 2) - go func() { - _, err := io.Copy(remote, local) - errc <- err - }() - remoteReader := deBuffer(br, remote) - go func() { - _, err := io.Copy(local, remoteReader) - errc <- err - }() + err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT) + if err != nil { + return fmt.Errorf("failed to store JWT: %w", err) + } - select { - case err := <-errc: + return tun.run(ctx, eventSink, local, rawJWT, retryCount+1) + } else if err != nil { + _ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey()) return err - case <-ctx.Done(): - return nil } + + return nil } func (tun *Tunnel) jwtCacheKey() string { return fmt.Sprintf("%s|%v", tun.cfg.proxyHost, tun.cfg.tlsConfig != nil) } - -func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader { - if br.Buffered() == 0 { - return underlying - } - return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying) -} diff --git a/tunnel/tunnel_http1.go b/tunnel/tunnel_http1.go new file mode 100644 index 0000000..6ef831e --- /dev/null +++ b/tunnel/tunnel_http1.go @@ -0,0 +1,117 @@ +package tunnel + +import ( + "bufio" + "context" + "crypto/tls" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" +) + +type http1tunneler struct { + cfg *config +} + +func (t *http1tunneler) TunnelTCP( + ctx context.Context, + eventSink EventSink, + local io.ReadWriter, + rawJWT string, +) error { + eventSink.OnConnecting(ctx) + + hdr := http.Header{} + if rawJWT != "" { + hdr.Set("Authorization", "Pomerium "+rawJWT) + } + + req := (&http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: t.cfg.dstHost}, + Host: t.cfg.dstHost, + Header: hdr, + }).WithContext(ctx) + + var remote net.Conn + var err error + if t.cfg.tlsConfig != nil { + remote, err = (&tls.Dialer{Config: t.cfg.tlsConfig}).DialContext(ctx, "tcp", t.cfg.proxyHost) + } else { + remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", t.cfg.proxyHost) + } + if err != nil { + return fmt.Errorf("failed to establish connection to proxy: %w", err) + } + defer func() { + _ = remote.Close() + log.Println("connection closed") + }() + if done := ctx.Done(); done != nil { + go func() { + <-done + _ = remote.Close() + }() + } + + err = req.Write(remote) + if err != nil { + return err + } + + br := bufio.NewReader(remote) + res, err := http.ReadResponse(br, req) + if err != nil { + return fmt.Errorf("failed to read HTTP response: %w", err) + } + defer func() { + _ = res.Body.Close() + }() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusServiceUnavailable: + return errUnavailable + case http.StatusMovedPermanently, + http.StatusFound, + http.StatusTemporaryRedirect, + http.StatusPermanentRedirect: + return errUnauthenticated + default: + return fmt.Errorf("invalid http response code: %d", res.StatusCode) + } + + log.Println("connection established") + eventSink.OnConnected(ctx) + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(remote, local) + errc <- err + }() + remoteReader := deBuffer(br, remote) + go func() { + _, err := io.Copy(local, remoteReader) + errc <- err + }() + + select { + case err = <-errc: + case <-ctx.Done(): + err = context.Cause(ctx) + } + + eventSink.OnDisconnected(ctx, err) + + return err +} + +func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader { + if br.Buffered() == 0 { + return underlying + } + return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying) +} diff --git a/tunnel/tunnel_http2.go b/tunnel/tunnel_http2.go new file mode 100644 index 0000000..aad58b3 --- /dev/null +++ b/tunnel/tunnel_http2.go @@ -0,0 +1,116 @@ +package tunnel + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "log" + "net/http" + "net/url" + + "golang.org/x/net/http2" +) + +type http2tunneler struct { + cfg *config +} + +func (t *http2tunneler) TunnelTCP( + ctx context.Context, + eventSink EventSink, + local io.ReadWriter, + rawJWT string, +) error { + eventSink.OnConnecting(ctx) + + hdr := http.Header{} + if rawJWT != "" { + hdr.Set("Authorization", "Pomerium "+rawJWT) + } + + if t.cfg.tlsConfig == nil { + return fmt.Errorf("%w: http2 requires TLS", errUnsupported) + } + + cfg := t.cfg.tlsConfig.Clone() + cfg.NextProtos = []string{"h2"} + + raw, err := (&tls.Dialer{Config: cfg}).DialContext(ctx, "tcp", t.cfg.proxyHost) + if err != nil { + return fmt.Errorf("failed to establish connection to proxy: %w", err) + } + defer func() { + _ = raw.Close() + log.Println("connection closed") + }() + + remote, ok := raw.(*tls.Conn) + if !ok { + return fmt.Errorf("unexpected connection type returned from dial: %T", raw) + } + + protocol := remote.ConnectionState().NegotiatedProtocol + if protocol != "h2" { + return fmt.Errorf("%w: unexpected TLS protocol: %s", errUnsupported, protocol) + } + + cc, err := (&http2.Transport{}).NewClientConn(remote) + if err != nil { + return fmt.Errorf("failed to establish http2 connection: %w", err) + } + defer cc.Close() + + pr, pw := io.Pipe() + + req := (&http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: t.cfg.dstHost}, + Host: t.cfg.dstHost, + Header: hdr, + Body: pr, + ContentLength: -1, + }).WithContext(ctx) + + res, err := cc.RoundTrip(req) + if err != nil { + return fmt.Errorf("error making http2 connect request: %w", err) + } + defer res.Body.Close() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusServiceUnavailable: + return errUnavailable + case http.StatusMovedPermanently, + http.StatusFound, + http.StatusTemporaryRedirect, + http.StatusPermanentRedirect: + return errUnauthenticated + default: + return fmt.Errorf("invalid http response code: %d", res.StatusCode) + } + + log.Println("connection established via http2") + eventSink.OnConnected(ctx) + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(pw, local) + errc <- err + }() + go func() { + _, err := io.Copy(local, res.Body) + errc <- err + }() + + select { + case err = <-errc: + case <-ctx.Done(): + err = context.Cause(ctx) + } + + eventSink.OnDisconnected(ctx, err) + + return err +} diff --git a/tunnel/tunnel_http2_test.go b/tunnel/tunnel_http2_test.go new file mode 100644 index 0000000..0840e37 --- /dev/null +++ b/tunnel/tunnel_http2_test.go @@ -0,0 +1,70 @@ +package tunnel + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTCPTunnelViaHTTP2(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !assert.Equal(t, "CONNECT", r.Method) { + return + } + if !assert.Equal(t, "Pomerium JWT", r.Header.Get("Authorization")) { + return + } + if !assert.Equal(t, "example.com:9999", r.Host) { + return + } + + defer r.Body.Close() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + buf := make([]byte, 4) + _, err := io.ReadFull(r.Body, buf) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3, 4}, buf) + + _, _ = w.Write([]byte{5, 6, 7, 8}) + })) + srv.EnableHTTP2 = true + srv.StartTLS() + + c1, c2 := net.Pipe() + go func() { + _, _ = c1.Write([]byte{1, 2, 3, 4}) + }() + go func() { + buf := make([]byte, 4) + _, err := io.ReadFull(c1, buf) + assert.NoError(t, err) + assert.Equal(t, []byte{5, 6, 7, 8}, buf) + _ = c1.Close() + }() + + tun := &http2tunneler{ + getConfig( + WithDestinationHost("example.com:9999"), + WithProxyHost(srv.Listener.Addr().String()), + WithTLSConfig(&tls.Config{ + InsecureSkipVerify: true, + }), + ), + } + err := tun.TunnelTCP(ctx, DiscardEvents(), c2, "JWT") + assert.NoError(t, err) +} diff --git a/tunnel/tunnel_http3.go b/tunnel/tunnel_http3.go new file mode 100644 index 0000000..d2d33ce --- /dev/null +++ b/tunnel/tunnel_http3.go @@ -0,0 +1,98 @@ +package tunnel + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + + "github.com/quic-go/quic-go/http3" +) + +type http3tunneler struct { + cfg *config +} + +func (t *http3tunneler) TunnelTCP( + ctx context.Context, + eventSink EventSink, + local io.ReadWriter, + rawJWT string, +) error { + eventSink.OnConnecting(ctx) + + cfg := t.cfg.tlsConfig + if cfg == nil { + return fmt.Errorf("http/3: %w: TLS is required", errUnsupported) + } + cfg = cfg.Clone() + cfg.NextProtos = []string{http3.NextProtoH3} + + transport := (&http3.Transport{ + TLSClientConfig: cfg, + }) + defer func() { + transport.Close() + log.Println("connection closed") + }() + + pr, pw := io.Pipe() + + u, err := url.Parse("https://" + t.cfg.proxyHost) + if err != nil { + return fmt.Errorf("http/3: failed to parse proxy URL: %w", err) + } + hdr := http.Header{} + if rawJWT != "" { + hdr.Set("Authorization", "Pomerium "+rawJWT) + } + res, err := transport.RoundTrip(&http.Request{ + Method: http.MethodConnect, + URL: u, + Host: t.cfg.dstHost, + Header: hdr, + ContentLength: -1, + Body: pr, + }) + if err != nil { + return fmt.Errorf("http/3: %w: failed to make connect request: %w", errUnsupported, err) + } + defer res.Body.Close() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusServiceUnavailable: + return errUnavailable + case http.StatusMovedPermanently, + http.StatusFound, + http.StatusTemporaryRedirect, + http.StatusPermanentRedirect: + return errUnauthenticated + default: + return fmt.Errorf("http/3: invalid response code: %d", res.StatusCode) + } + + log.Println("http/3: connection established") + eventSink.OnConnected(ctx) + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(pw, local) + errc <- err + }() + go func() { + _, err := io.Copy(local, res.Body) + errc <- err + }() + + select { + case err = <-errc: + case <-ctx.Done(): + err = context.Cause(ctx) + } + + eventSink.OnDisconnected(ctx, err) + return err +} diff --git a/tunnel/tunnel_http3_test.go b/tunnel/tunnel_http3_test.go new file mode 100644 index 0000000..b4442c1 --- /dev/null +++ b/tunnel/tunnel_http3_test.go @@ -0,0 +1,134 @@ +package tunnel + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/quic-go/quic-go/http3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/cli/internal/testutil" +) + +var testCert = []byte(`-----BEGIN CERTIFICATE----- +MIIDOTCCAiGgAwIBAgIQSRJrEpBGFc7tNb1fb5pKFzANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA6Gba5tHV1dAKouAaXO3/ebDUU4rvwCUg/CNaJ2PT5xLD4N1Vcb8r +bFSW2HXKq+MPfVdwIKR/1DczEoAGf/JWQTW7EgzlXrCd3rlajEX2D73faWJekD0U +aUgz5vtrTXZ90BQL7WvRICd7FlEZ6FPOcPlumiyNmzUqtwGhO+9ad1W5BqJaRI6P +YfouNkwR6Na4TzSj5BrqUfP0FwDizKSJ0XXmh8g8G9mtwxOSN3Ru1QFc61Xyeluk +POGKBV/q6RBNklTNe0gI8usUMlYyoC7ytppNMW7X2vodAelSu25jgx2anj9fDVZu +h7AXF5+4nJS4AAt0n1lNY7nGSsdZas8PbQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE +AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud +DgQWBBStsdjh3/JCXXYlQryOrL4Sh7BW5TAuBgNVHREEJzAlggtleGFtcGxlLmNv +bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAxWGI +5NhpF3nwwy/4yB4i/CwwSpLrWUa70NyhvprUBC50PxiXav1TeDzwzLx/o5HyNwsv +cxv3HdkLW59i/0SlJSrNnWdfZ19oTcS+6PtLoVyISgtyN6DpkKpdG1cOkW3Cy2P2 ++tK/tKHRP1Y/Ra0RiDpOAmqn0gCOFGz8+lqDIor/T7MTpibL3IxqWfPrvfVRHL3B +grw/ZQTTIVjjh4JBSW3WyWgNo/ikC1lrVxzl4iPUGptxT36Cr7Zk2Bsg0XqwbOvK +5d+NTDREkSnUbie4GeutujmX3Dsx88UiV6UY/4lHJa6I5leHUNOHahRbpbWeOfs/ +WkBKOclmOV2xlTVuPw== +-----END CERTIFICATE-----`) + +var testKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDoZtrm0dXV0Aqi +4Bpc7f95sNRTiu/AJSD8I1onY9PnEsPg3VVxvytsVJbYdcqr4w99V3AgpH/UNzMS +gAZ/8lZBNbsSDOVesJ3euVqMRfYPvd9pYl6QPRRpSDPm+2tNdn3QFAvta9EgJ3sW +URnoU85w+W6aLI2bNSq3AaE771p3VbkGolpEjo9h+i42TBHo1rhPNKPkGupR8/QX +AOLMpInRdeaHyDwb2a3DE5I3dG7VAVzrVfJ6W6Q84YoFX+rpEE2SVM17SAjy6xQy +VjKgLvK2mk0xbtfa+h0B6VK7bmODHZqeP18NVm6HsBcXn7iclLgAC3SfWU1jucZK +x1lqzw9tAgMBAAECggEABWzxS1Y2wckblnXY57Z+sl6YdmLV+gxj2r8Qib7g4ZIk +lIlWR1OJNfw7kU4eryib4fc6nOh6O4AWZyYqAK6tqNQSS/eVG0LQTLTTEldHyVJL +dvBe+MsUQOj4nTndZW+QvFzbcm2D8lY5n2nBSxU5ypVoKZ1EqQzytFcLZpTN7d89 +EPj0qDyrV4NZlWAwL1AygCwnlwhMQjXEalVF1ylXwU3QzyZ/6MgvF6d3SSUlh+sq +XefuyigXw484cQQgbzopv6niMOmGP3of+yV4JQqUSb3IDmmT68XjGd2Dkxl4iPki +6ZwXf3CCi+c+i/zVEcufgZ3SLf8D99kUGE7v7fZ6AQKBgQD1ZX3RAla9hIhxCf+O +3D+I1j2LMrdjAh0ZKKqwMR4JnHX3mjQI6LwqIctPWTU8wYFECSh9klEclSdCa64s +uI/GNpcqPXejd0cAAdqHEEeG5sHMDt0oFSurL4lyud0GtZvwlzLuwEweuDtvT9cJ +Wfvl86uyO36IW8JdvUprYDctrQKBgQDycZ697qutBieZlGkHpnYWUAeImVA878sJ +w44NuXHvMxBPz+lbJGAg8Cn8fcxNAPqHIraK+kx3po8cZGQywKHUWsxi23ozHoxo ++bGqeQb9U661TnfdDspIXia+xilZt3mm5BPzOUuRqlh4Y9SOBpSWRmEhyw76w4ZP +OPxjWYAgwQKBgA/FehSYxeJgRjSdo+MWnK66tjHgDJE8bYpUZsP0JC4R9DL5oiaA +brd2fI6Y+SbyeNBallObt8LSgzdtnEAbjIH8uDJqyOmknNePRvAvR6mP4xyuR+Bv +m+Lgp0DMWTw5J9CKpydZDItc49T/mJ5tPhdFVd+am0NAQnmr1MCZ6nHxAoGABS3Y +LkaC9FdFUUqSU8+Chkd/YbOkuyiENdkvl6t2e52jo5DVc1T7mLiIrRQi4SI8N9bN +/3oJWCT+uaSLX2ouCtNFunblzWHBrhxnZzTeqVq4SLc8aESAnbslKL4i8/+vYZlN +s8xtiNcSvL+lMsOBORSXzpj/4Ot8WwTkn1qyGgECgYBKNTypzAHeLE6yVadFp3nQ +Ckq9yzvP/ib05rvgbvrne00YeOxqJ9gtTrzgh7koqJyX1L4NwdkEza4ilDWpucn0 +xiUZS4SoaJq6ZvcBYS62Yr1t8n09iG47YL8ibgtmH3L+svaotvpVxVK+d7BLevA/ +ZboOWVe3icTy64BT3OQhmg== +-----END RSA PRIVATE KEY-----`) + +func TestTCPTunnelViaHTTP3(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + port := testutil.GetPort(t) + + cert, err := tls.X509KeyPair(testCert, testKey) + require.NoError(t, err) + + srv := &http3.Server{ + Addr: "127.0.0.1:" + port, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !assert.Equal(t, "CONNECT", r.Method) { + return + } + if !assert.Equal(t, "Pomerium JWT", r.Header.Get("Authorization")) { + return + } + if !assert.Equal(t, "example.com:9999", r.Host) { + return + } + + defer r.Body.Close() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + buf := make([]byte, 4) + _, err := io.ReadFull(r.Body, buf) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3, 4}, buf) + + _, _ = w.Write([]byte{5, 6, 7, 8}) + }), + } + t.Cleanup(func() { srv.Close() }) + go func() { _ = srv.ListenAndServe() }() + + c1, c2 := net.Pipe() + go func() { + _, _ = c1.Write([]byte{1, 2, 3, 4}) + }() + go func() { + buf := make([]byte, 4) + _, err := io.ReadFull(c1, buf) + assert.NoError(t, err) + assert.Equal(t, []byte{5, 6, 7, 8}, buf) + _ = c1.Close() + }() + + tun := &http3tunneler{ + getConfig( + WithDestinationHost("example.com:9999"), + WithProxyHost("127.0.0.1:"+port), + WithTLSConfig(&tls.Config{ + InsecureSkipVerify: true, + }), + ), + } + err = tun.TunnelTCP(ctx, DiscardEvents(), c2, "JWT") + assert.NoError(t, err) +} diff --git a/tunnel/tunnel_tcp.go b/tunnel/tunnel_tcp.go new file mode 100644 index 0000000..fdec3fc --- /dev/null +++ b/tunnel/tunnel_tcp.go @@ -0,0 +1,61 @@ +package tunnel + +import ( + "context" + "io" + "log" + "net/http" + "strings" +) + +// A TCPTunneler tunnels TCP traffic. +type TCPTunneler interface { + TunnelTCP( + ctx context.Context, + eventSink EventSink, + local io.ReadWriter, + rawJWT string, + ) error +} + +// PickTCPTunneler picks a tcp tunneler for the given proxy. +func (tun *Tunnel) pickTCPTunneler(ctx context.Context) TCPTunneler { + fallback := &http1tunneler{cfg: tun.cfg} + + // if we're not using TLS, only HTTP1 is supported + if tun.cfg.tlsConfig == nil { + log.Println("pick-tcp-tunneler: tls not enabled, using http1") + return fallback + } + + client := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: tun.cfg.tlsConfig, + }, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+tun.cfg.proxyHost, nil) + if err != nil { + log.Println("pick-tcp-tunneler: failed to create probe request, falling back to http1", err) + return fallback + } + + res, err := client.Do(req) + if err != nil { + log.Println("pick-tcp-tunneler: failed to make probe request, falling back to http1", err) + return fallback + } + res.Body.Close() + + if v := res.Header.Get("Alt-Svc"); strings.Contains(v, "h3") { + log.Println("pick-tcp-tunneler: using http3") + return &http3tunneler{cfg: tun.cfg} + } else if res.ProtoMajor == 2 { + log.Println("pick-tcp-tunneler: using http2") + return &http2tunneler{cfg: tun.cfg} + } + + log.Println("pick-tcp-tunneler: using http1") + return fallback +}