Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http2 connect support #472

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion api/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ func tunnelAcceptLoop(ctx context.Context, id string, li net.Listener, tun Tunne
if err != nil {
log.Printf("error serving local connection %s: %v\n", id, err)
}
cEvt.OnDisconnected(ctx, err)
}(c)
}
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ require (
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.29.0
golang.org/x/net v0.30.0
golang.org/x/sync v0.9.0
golang.org/x/sys v0.27.0
google.golang.org/grpc v1.68.0
Expand Down Expand Up @@ -57,7 +58,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/tools v0.24.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
Expand Down
141 changes: 34 additions & 107 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
package tunnel

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"time"

Expand All @@ -20,6 +17,12 @@ import (
"github.com/pomerium/cli/jwt"
)

var (
errUnavailable = errors.New("unavailable")
errUnauthenticated = errors.New("unauthenticated")
errUnsupported = errors.New("unsupported")
)

// A Tunnel represents a TCP tunnel over HTTP Connect.
type Tunnel struct {
cfg *config
Expand Down Expand Up @@ -106,121 +109,45 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, eventSink Event
}

func (tun *Tunnel) run(ctx context.Context, eventSink EventSink, local io.ReadWriter, rawJWT string, retryCount int) error {
eventSink.OnConnecting(ctx)

hdr := http.Header{}
if rawJWT != "" {
hdr.Set("Authorization", "Pomerium "+rawJWT)
}

req := (&http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: tun.cfg.dstHost},
Host: tun.cfg.dstHost,
Header: hdr,
}).WithContext(ctx)

var remote net.Conn
var err error
if tun.cfg.tlsConfig != nil {
remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
} else {
remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
}
if err != nil {
return fmt.Errorf("failed to establish connection to proxy: %w", err)
}
defer func() {
_ = remote.Close()
log.Println("connection closed")
}()
if done := ctx.Done(); done != nil {
go func() {
<-done
_ = remote.Close()
}()
}

err = req.Write(remote)
if err != nil {
return err
err := (&http2tunnel{cfg: tun.cfg}).TunnelTCP(ctx, eventSink, local, rawJWT)
if errors.Is(err, errUnsupported) {
// fallback to http1
err = (&http1tunnel{cfg: tun.cfg}).TunnelTCP(ctx, eventSink, local, rawJWT)
}

br := bufio.NewReader(remote)
res, err := http.ReadResponse(br, req)
if err != nil {
return fmt.Errorf("failed to read HTTP response: %w", err)
}
defer func() {
_ = res.Body.Close()
}()
switch res.StatusCode {
case http.StatusOK:
case http.StatusServiceUnavailable:
if errors.Is(err, errUnavailable) {
// don't delete the JWT if we get a service unavailable
return fmt.Errorf("invalid http response code: %s", res.Status)
case http.StatusMovedPermanently,
http.StatusFound,
http.StatusTemporaryRedirect,
http.StatusPermanentRedirect:
if retryCount == 0 {
_ = remote.Close()

serverURL := &url.URL{
Scheme: "http",
Host: tun.cfg.proxyHost,
}
if tun.cfg.tlsConfig != nil {
serverURL.Scheme = "https"
}

rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { eventSink.OnAuthRequired(ctx, authURL) })
if err != nil {
return fmt.Errorf("failed to get authentication JWT: %w", err)
}

err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT)
if err != nil {
return fmt.Errorf("failed to store JWT: %w", err)
}

return tun.run(ctx, eventSink, local, rawJWT, retryCount+1)
return err
} else if errors.Is(err, errUnauthenticated) && retryCount == 0 {
serverURL := &url.URL{
Scheme: "http",
Host: tun.cfg.proxyHost,
}
if tun.cfg.tlsConfig != nil {
serverURL.Scheme = "https"
}
fallthrough
default:
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
return fmt.Errorf("invalid http response code: %d", res.StatusCode)
}

log.Println("connection established")
eventSink.OnConnected(ctx)
rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) {
eventSink.OnAuthRequired(ctx, authURL)
})
if err != nil {
return fmt.Errorf("failed to get authentication JWT: %w", err)
}

errc := make(chan error, 2)
go func() {
_, err := io.Copy(remote, local)
errc <- err
}()
remoteReader := deBuffer(br, remote)
go func() {
_, err := io.Copy(local, remoteReader)
errc <- err
}()
err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT)
if err != nil {
return fmt.Errorf("failed to store JWT: %w", err)
}

select {
case err := <-errc:
return tun.run(ctx, eventSink, local, rawJWT, retryCount+1)
} else if err != nil {
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
return err
case <-ctx.Done():
return nil
}

return nil
}

func (tun *Tunnel) jwtCacheKey() string {
return fmt.Sprintf("%s|%v", tun.cfg.proxyHost, tun.cfg.tlsConfig != nil)
}

func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader {
if br.Buffered() == 0 {
return underlying
}
return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying)
}
117 changes: 117 additions & 0 deletions tunnel/tunnel_http1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package tunnel

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
)

type http1tunnel struct {
cfg *config
}

func (t *http1tunnel) TunnelTCP(
ctx context.Context,
eventSink EventSink,
local io.ReadWriter,
rawJWT string,
) error {
eventSink.OnConnecting(ctx)

hdr := http.Header{}
if rawJWT != "" {
hdr.Set("Authorization", "Pomerium "+rawJWT)
}

req := (&http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: t.cfg.dstHost},
Host: t.cfg.dstHost,
Header: hdr,
}).WithContext(ctx)

var remote net.Conn
var err error
if t.cfg.tlsConfig != nil {
remote, err = (&tls.Dialer{Config: t.cfg.tlsConfig}).DialContext(ctx, "tcp", t.cfg.proxyHost)
} else {
remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", t.cfg.proxyHost)
}
if err != nil {
return fmt.Errorf("failed to establish connection to proxy: %w", err)
}
defer func() {
_ = remote.Close()
log.Println("connection closed")
}()
if done := ctx.Done(); done != nil {
go func() {
<-done
_ = remote.Close()
}()
}

err = req.Write(remote)
if err != nil {
return err
}

br := bufio.NewReader(remote)
res, err := http.ReadResponse(br, req)
if err != nil {
return fmt.Errorf("failed to read HTTP response: %w", err)
}
defer func() {
_ = res.Body.Close()
}()

switch res.StatusCode {
case http.StatusOK:
case http.StatusServiceUnavailable:
return errUnavailable
case http.StatusMovedPermanently,
http.StatusFound,
http.StatusTemporaryRedirect,
http.StatusPermanentRedirect:
return errUnauthenticated
default:
return fmt.Errorf("invalid http response code: %d", res.StatusCode)
}

log.Println("connection established")
eventSink.OnConnected(ctx)

errc := make(chan error, 2)
go func() {
_, err := io.Copy(remote, local)
errc <- err
}()
remoteReader := deBuffer(br, remote)
go func() {
_, err := io.Copy(local, remoteReader)
errc <- err
}()

select {
case err = <-errc:
case <-ctx.Done():
err = context.Cause(ctx)
}

eventSink.OnDisconnected(ctx, err)

return err
}

func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader {
if br.Buffered() == 0 {
return underlying
}
return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying)
}
Loading
Loading