diff --git a/auth.go b/auth.go index 66df968..d9dd913 100644 --- a/auth.go +++ b/auth.go @@ -7,7 +7,9 @@ package tscaddy import ( "fmt" + "net" "net/http" + "reflect" "strings" "github.com/caddyserver/caddy/v2" @@ -40,6 +42,39 @@ func (Auth) CaddyModule() caddy.ModuleInfo { } } +// findTsnetListener recursively searches ln for embedded child net.Listener structs +// until it finds the tsnetListener or runs out. +// the 2nd return value is true if it was found, false if it wasn't. +func findTsnetListener(ln net.Listener) (tsnetListener, bool) { + // if the input is a tsnetListener, return it. + if tsn, ok := ln.(tsnetListener); ok { + return tsn, true + } + + s := reflect.ValueOf(ln) + // make sure s is a struct instead of a pointer to one + if s.Kind() == reflect.Ptr { + s = s.Elem() + } + + if s.Kind() != reflect.Struct { + return nil, false + } + + innerLn := s.FieldByName("Listener") + if innerLn.IsZero() { + // no more child/embedded listeners left + return nil, false + } + + // if the "Listener" child is a net.Listener, run the function again on its child. + if child, ok := innerLn.Interface().(net.Listener); ok { + return findTsnetListener(child) + } + + return nil, false +} + // client returns the tailscale LocalClient for the TailscaleAuth module. // If the LocalClient has not already been configured, the provided request will be used to // lookup the tailscale node that serviced the request, and get the associated LocalClient. @@ -52,13 +87,11 @@ func (ta *Auth) client(r *http.Request) (*tailscale.LocalClient, error) { // server. server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server) for _, listener := range server.Listeners() { - if tsServerListener, ok := listener.(*tsnetServerListener); ok { - if tsl, ok := tsServerListener.Listener.(tsnetListener); ok { - var err error - ta.localclient, err = tsl.Server().LocalClient() - if err != nil { - return nil, err - } + if tsl, ok := findTsnetListener(listener); ok { + var err error + ta.localclient, err = tsl.Server().LocalClient() + if err != nil { + return nil, err } } }