diff --git a/hscontrol/app.go b/hscontrol/app.go index 939552ad72..934e0f33e0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -217,9 +217,16 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { ) } + if cfg.DERP.ServerVerifyClients { + t := http.DefaultTransport.(*http.Transport) + t.RegisterProtocol( + derpServer.DerpVerifyScheme, + derpServer.NewDERPVerifyTransport(app.handleVerifyRequest), + ) + } + embeddedDERPServer, err := derpServer.NewDERPServer( cfg.ServerURL, - app.VerifyHandler, key.NodePrivate(*derpServerKey), &cfg.DERP, ) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 8c54e2ab1f..fa37a1abaa 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -2,12 +2,13 @@ package server import ( "bufio" + "bytes" "context" "encoding/json" "fmt" + "io" "net" "net/http" - "net/http/httptest" "net/netip" "net/url" "strconv" @@ -30,6 +31,7 @@ import ( // headers and it will begin writing & reading the DERP protocol immediately // following its HTTP request. const fastStartHeader = "Derp-Fast-Start" +const DerpVerifyScheme = "derp-verify" type DERPServer struct { serverURL string @@ -40,7 +42,6 @@ type DERPServer struct { func NewDERPServer( serverURL string, - verifyHandler http.HandlerFunc, derpKey key.NodePrivate, cfg *types.DERPConfig, ) (*DERPServer, error) { @@ -48,11 +49,7 @@ func NewDERPServer( server := derp.NewServer(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains if cfg.ServerVerifyClients { - t := http.DefaultTransport.(*http.Transport) - t.RegisterProtocol("headscale", &HeadscaleTransport{ - verifyHandler: verifyHandler, - }) - server.SetVerifyClientURL("headscale://verify") + server.SetVerifyClientURL(DerpVerifyScheme + "://verify") server.SetVerifyClientURLFailOpen(false) } @@ -372,13 +369,31 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { } } -type HeadscaleTransport struct { - verifyHandler http.HandlerFunc +func NewDERPVerifyTransport(handleVerifyRequest func(*http.Request, io.Writer) error) *DERPVerifyTransport { + return &DERPVerifyTransport{ + handleVerifyRequest: handleVerifyRequest, + } +} + +type DERPVerifyTransport struct { + handleVerifyRequest func(*http.Request, io.Writer) error } -func (t *HeadscaleTransport) RoundTrip(req *http.Request) (*http.Response, error) { - recorder := httptest.NewRecorder() - t.verifyHandler(recorder, req) - resp := recorder.Result() +func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + buf := new(bytes.Buffer) + if err := t.handleVerifyRequest(req, buf); err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to handle verify request") + + return nil, err + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(buf), + } + return resp, nil } diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 3858df9339..87766810a2 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -59,26 +59,35 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) func (h *Headscale) handleVerifyRequest( req *http.Request, -) (bool, error) { + writer io.Writer, +) error { body, err := io.ReadAll(req.Body) if err != nil { - return false, fmt.Errorf("cannot read request body: %w", err) + return fmt.Errorf("cannot read request body: %w", err) } var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { - return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err) + return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err) } nodes, err := h.db.ListNodes() if err != nil { - return false, fmt.Errorf("cannot list nodes: %w", err) + return fmt.Errorf("cannot list nodes: %w", err) } - return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil + resp := &tailcfg.DERPAdmitClientResponse{ + Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), + } + if err = json.NewEncoder(writer).Encode(resp); err != nil { + return fmt.Errorf("cannot encode response: %w", err) + } + + return nil } -// see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, Derp use verifyClientsURL to verify whether a client is allowed to connect to the DERP server. +// VerifyHandler see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, +// DERP use verifyClientsURL to verify whether a client is allowed to connect to the DERP server. func (h *Headscale) VerifyHandler( writer http.ResponseWriter, req *http.Request, @@ -92,28 +101,18 @@ func (h *Headscale) VerifyHandler( Str("handler", "/verify"). Msg("verify client") - allow, err := h.handleVerifyRequest(req) - if err != nil { + if err := h.handleVerifyRequest(req, writer); err != nil { log.Error(). Caller(). Err(err). Msg("Failed to verify client") http.Error(writer, "Internal error", http.StatusInternalServerError) - } - resp := tailcfg.DERPAdmitClientResponse{ - Allow: allow, + return } writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(http.StatusOK) - err = json.NewEncoder(writer).Encode(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } } // KeyHandler provides the Headscale pub key