Skip to content

Commit

Permalink
auth.go: find tsnet.Listener using reflection
Browse files Browse the repository at this point in the history
Signed-off-by: hello <[email protected]>
  • Loading branch information
beep-beep-beep-boop committed Jun 18, 2024
1 parent 726f8df commit d8d2947
Showing 1 changed file with 40 additions and 7 deletions.
47 changes: 40 additions & 7 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ package tscaddy

import (
"fmt"
"net"
"net/http"
"reflect"
"strings"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -40,6 +42,39 @@ func (Auth) CaddyModule() caddy.ModuleInfo {
}
}

// findTsnetListener recursively searches ln for embedded child net.Listener structs
// until it finds the tsnetListener or runs out.
// the 2nd return value is true if it was found, false if it wasn't.
func findTsnetListener(ln net.Listener) (tsnetListener, bool) {
// if the input is a tsnetListener, return it.
if tsn, ok := ln.(tsnetListener); ok {
return tsn, true
}

s := reflect.ValueOf(ln)
// make sure s is a struct instead of a pointer to one
if s.Kind() == reflect.Ptr {
s = s.Elem()
}

if s.Kind() != reflect.Struct {
return nil, false
}

innerLn := s.FieldByName("Listener")
if innerLn.IsZero() {
// no more child/embedded listeners left
return nil, false
}

// if the "Listener" child is a net.Listener, run the function again on its child.
if child, ok := innerLn.Interface().(net.Listener); ok {
return findTsnetListener(child)
}

return nil, false
}

// client returns the tailscale LocalClient for the TailscaleAuth module.
// If the LocalClient has not already been configured, the provided request will be used to
// lookup the tailscale node that serviced the request, and get the associated LocalClient.
Expand All @@ -52,13 +87,11 @@ func (ta *Auth) client(r *http.Request) (*tailscale.LocalClient, error) {
// server.
server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
for _, listener := range server.Listeners() {
if tsServerListener, ok := listener.(*tsnetServerListener); ok {
if tsl, ok := tsServerListener.Listener.(tsnetListener); ok {
var err error
ta.localclient, err = tsl.Server().LocalClient()
if err != nil {
return nil, err
}
if tsl, ok := findTsnetListener(listener); ok {
var err error
ta.localclient, err = tsl.Server().LocalClient()
if err != nil {
return nil, err
}
}
}
Expand Down

0 comments on commit d8d2947

Please sign in to comment.