Skip to content

Commit

Permalink
abstract out context and enrollment id header setting
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Aug 27, 2023
1 parent fc9260c commit ee33b3c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
11 changes: 10 additions & 1 deletion cmd/nanomdm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ const (
endpointAPIVersion = "/version"
)

const (
EnrollmentIDHeader = "X-Enrollment-ID"
TraceIDHeader = "X-Trace-ID"
)

func main() {
cliStorage := cli.NewStorage()
flag.Var(&cliStorage.Storage, "storage", "name of storage backend")
Expand Down Expand Up @@ -166,7 +171,11 @@ func main() {

if *flAuthProxy != "" {
var authProxyHandler http.Handler
authProxyHandler, err = authproxy.New(*flAuthProxy, logger.With("handler", "authproxy"))
authProxyHandler, err = authproxy.New(*flAuthProxy,
authproxy.WithLogger(logger.With("handler", "authproxy")),
authproxy.WithHeaderFunc(EnrollmentIDHeader, httpmdm.GetEnrollmentID),
authproxy.WithHeaderFunc(TraceIDHeader, mdmhttp.GetTraceID),
)
if err != nil {
stdlog.Fatal(err)
}
Expand Down
70 changes: 54 additions & 16 deletions http/authproxy/authproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,87 @@
package authproxy

import (
"context"
"net/http"
"net/http/httputil"
"net/url"

mdmhttp "github.com/micromdm/nanomdm/http"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log"
"github.com/micromdm/nanomdm/log/ctxlog"
)

const (
EnrollmentIDHeader = "X-Enrollment-ID"
TraceIDHeader = "X-Trace-ID"
)
// HeaderFunc takes an HTTP request and returns a string value.
// Ostensibly to be set in a header on the proxy target.
type HeaderFunc func(context.Context) string

type config struct {
logger log.Logger
fwdSig bool
headerFuncs map[string]HeaderFunc
}

type Option func(*config)

// WithLogger sets a logger for error reporting.
func WithLogger(logger log.Logger) Option {
return func(c *config) {
c.logger = logger
}
}

// WithHeaderFunc configures fn to be called and added as an HTTP header to the proxy target request.
func WithHeaderFunc(header string, fn HeaderFunc) Option {
return func(c *config) {
c.headerFuncs[header] = fn
}
}

// WithForwardMDMSignature forwards the MDM-Signature header onto the proxy destination.
// This option is off by default because the header adds about two kilobytes to the request.
func WithForwardMDMSignature() Option {
return func(c *config) {
c.fwdSig = true
}
}

// New creates a new NanoMDM enrollment authenticating reverse proxy.
// This reverse proxy is mostly the standard httputil proxy. It depends
// on middleware HTTP handlers to enforce authentication and set the
// context value for the enrollment ID.
func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) {
func New(dest string, opts ...Option) (*httputil.ReverseProxy, error) {
config := &config{
logger: log.NopLogger,
headerFuncs: make(map[string]HeaderFunc),
}
for _, opt := range opts {
opt(config)
}
target, err := url.Parse(dest)
if err != nil {
return nil, err
}
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
ctxlog.Logger(r.Context(), logger).Info("err", err)
ctxlog.Logger(r.Context(), config.logger).Info("err", err)
// use the same error as the standrad reverse proxy
w.WriteHeader(http.StatusBadGateway)
}
dir := proxy.Director
proxy.Director = func(req *http.Request) {
dir(req)
req.Host = target.Host
// save the effort of forwarding this huge header
req.Header.Del("Mdm-Signature")
if id := httpmdm.GetEnrollmentID(req.Context()); id != "" {
req.Header.Set(EnrollmentIDHeader, id)
if !config.fwdSig {
// save the effort of forwarding this huge header
req.Header.Del("Mdm-Signature")
}
// TODO: this couples us to our specific idea of trace logging
// Perhaps have an optional config for header specificaiton?
if id := mdmhttp.GetTraceID(req.Context()); id != "" {
req.Header.Set(TraceIDHeader, id)
// set any headers we want to forward.
for k, fn := range config.headerFuncs {
if k == "" || fn == nil {
continue
}
if v := fn(req.Context()); v != "" {
req.Header.Set(k, v)
}
}
}
return proxy, nil
Expand Down

0 comments on commit ee33b3c

Please sign in to comment.