Skip to content

Commit

Permalink
Revised my basic web server and moved it into the xhttp package
Browse files Browse the repository at this point in the history
  • Loading branch information
richardwilkes committed Jan 7, 2024
1 parent 6ed291c commit d46053c
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 191 deletions.
53 changes: 24 additions & 29 deletions xio/network/xhttp/basic_auth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright ©2016-2022 by Richard A. Wilkes. All rights reserved.
// Copyright ©2016-2024 by Richard A. Wilkes. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, version 2.0. If a copy of the MPL was not distributed with
Expand All @@ -11,41 +11,36 @@
package xhttp

import (
"crypto/subtle"
"fmt"
"net/http"
)

// PasswordLookup provides a way to map a user in a realm to a password
type PasswordLookup func(user, realm string) string

// BasicAuth provides basic HTTP authentication.
type BasicAuth struct {
realm string
lookup PasswordLookup
}

// NewBasicAuth creates a new BasicAuth.
func NewBasicAuth(realm string, lookup PasswordLookup) *BasicAuth {
return &BasicAuth{realm: realm, lookup: lookup}
}

// Wrap an http.Handler.
func (auth *BasicAuth) Wrap(handler http.Handler) http.Handler {
return &wrapper{auth: auth, handler: handler}
}

type wrapper struct {
auth *BasicAuth
handler http.Handler
Realm string
// Lookup provides a way to map a user in a realm to a password. The returned password should have already been
// passed through the Hasher function.
Lookup func(user, realm string) ([]byte, bool)
Hasher func(input string) []byte
}

func (hw *wrapper) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if user, pw, ok := req.BasicAuth(); ok {
if pw == hw.auth.lookup(user, hw.auth.realm) {
hw.handler.ServeHTTP(w, req)
return
// Wrap an http.Handler, requiring Basic Authentication.
func (ba *BasicAuth) Wrap(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user, pw, ok := r.BasicAuth(); ok {
stored, found := ba.Lookup(user, ba.Realm)
passwordMatch := subtle.ConstantTimeCompare(ba.Hasher(pw), stored) == 1
if found && passwordMatch {
if md := MetadataFromRequest(r); md != nil {
md.User = user
md.Logger = md.Logger.With("user", user)
}
handler.ServeHTTP(w, r)
return
}
}
}
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, hw.auth.realm))
WriteHTTPStatus(w, http.StatusUnauthorized)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q, charset="UTF-8"`, ba.Realm))
ErrorStatus(w, http.StatusUnauthorized)
})
}
129 changes: 67 additions & 62 deletions xio/network/xhttp/web/server.go → xio/network/xhttp/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright ©2016-2023 by Richard A. Wilkes. All rights reserved.
// Copyright ©2016-2024 by Richard A. Wilkes. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, version 2.0. If a copy of the MPL was not distributed with
Expand All @@ -7,9 +7,7 @@
// This Source Code Form is "Incompatible With Secondary Licenses", as
// defined by the Mozilla Public License, version 2.0.

// Package web provides a web server with some standardized logging and
// handler wrapping.
package web
package xhttp

import (
"context"
Expand All @@ -18,15 +16,13 @@ import (
"log/slog"
"net"
"net/http"
"path"
"strconv"
"strings"
"time"

"github.com/richardwilkes/toolbox/atexit"
"github.com/richardwilkes/toolbox/errs"
"github.com/richardwilkes/toolbox/xio/network"
"github.com/richardwilkes/toolbox/xio/network/xhttp"
)

// Constants for protocols the server can provide.
Expand All @@ -35,18 +31,28 @@ const (
ProtocolHTTPS = "https"
)

type ctxKey int

const metadataKey ctxKey = 1

// Metadata holds auxiliary information for a request.
type Metadata struct {
Logger *slog.Logger
User string
}

// Server holds the data necessary for the server.
type Server struct {
WebServer *http.Server
Logger *slog.Logger
CertFile string
KeyFile string
ShutdownGracePeriod time.Duration
Logger *slog.Logger
WebServer *http.Server
Ports []int
ShutdownCallback func()
StartedChan chan any // If not nil, will be closed once the server is ready to accept connections
ShutdownGracePeriod time.Duration
ShutdownCallback func(*slog.Logger)
addresses []string
port int
clientHandler http.Handler
}

// Protocol returns the protocol this server is handling.
Expand Down Expand Up @@ -91,66 +97,31 @@ func (s *Server) Run() error {
if s.Logger == nil {
s.Logger = slog.Default()
}
handler := s.WebServer.Handler
s.WebServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
started := time.Now()
req.URL.Path = path.Clean(req.URL.Path)
req = req.WithContext(context.WithValue(req.Context(), routeKey, &route{path: req.URL.Path}))
sw := &xhttp.StatusResponseWriter{
Original: w,
Head: req.Method == http.MethodHead,
}
defer func() {
if recovered := recover(); recovered != nil {
err, ok := recovered.(error)
if !ok {
err = errs.Newf("%+v", recovered)
}
errs.LogTo(s.Logger, errs.NewWithCause("recovered from panic in handler", err))
sw.WriteHeader(http.StatusInternalServerError)
}
since := time.Since(started)
millis := int64(since / time.Millisecond)
micros := int64(since/time.Microsecond) - millis*1000
written := sw.BytesWritten()
s.Logger.Info("web", "status", sw.Status(), "elapsed", fmt.Sprintf("%d.%03dms", millis, micros),
"bytes", written, "method", req.Method, "url", req.URL)
}()
handler.ServeHTTP(sw, req)
})
s.clientHandler = s.WebServer.Handler
s.WebServer.Handler = s
var ln net.Listener
host, _, err := net.SplitHostPort(s.WebServer.Addr)
_, _, err := net.SplitHostPort(s.WebServer.Addr)
if err == nil {
ln, err = net.Listen("tcp", s.WebServer.Addr)
} else {
ports := s.Ports
if len(ports) == 0 {
ports = []int{0}
}
for _, one := range ports {
if ln, err = net.Listen("tcp", net.JoinHostPort(s.WebServer.Addr, strconv.Itoa(one))); err == nil {
break
}
}
ln, err = net.Listen("tcp", net.JoinHostPort(s.WebServer.Addr, "0"))
}
if err != nil {
return errs.Wrap(err)
}
listener := network.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}
var portStr string
if _, portStr, err = net.SplitHostPort(ln.Addr().String()); err != nil {
var host, portStr string
if host, portStr, err = net.SplitHostPort(ln.Addr().String()); err != nil {
return errs.Wrap(err)
}
if s.port, err = strconv.Atoi(portStr); err != nil {
return errs.Wrap(err)
}
s.addresses = network.AddressesForHost(host)
s.Logger.Info("listening", "protocol", s.Protocol(), "addresses", s.addresses, "port", s.port)
go func() {
if s.StartedChan != nil {
close(s.StartedChan)
}
}()
if s.StartedChan != nil {
go func() { close(s.StartedChan) }()
}
if s.Protocol() == ProtocolHTTPS {
err = s.WebServer.ServeTLS(listener, s.CertFile, s.KeyFile)
} else {
Expand All @@ -162,24 +133,58 @@ func (s *Server) Run() error {
return nil
}

// ServeHTTP implements the http.Handler interface.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
started := time.Now()
sw := &StatusResponseWriter{
Original: w,
Head: r.Method == http.MethodHead,
}
md := &Metadata{Logger: s.Logger.With("method", r.Method, "url", r.URL)}
r = r.WithContext(context.WithValue(r.Context(), metadataKey, md))
defer func() {
if recovered := recover(); recovered != nil {
err, ok := recovered.(error)
if !ok {
err = errs.Newf("%+v", recovered)
}
errs.LogTo(md.Logger, errs.NewWithCause("recovered from panic in handler", err))
ErrorStatus(sw, http.StatusInternalServerError)
}
since := time.Since(started)
millis := int64(since / time.Millisecond)
micros := int64(since/time.Microsecond) - millis*1000
written := sw.BytesWritten()
md.Logger.Info("web", "status", sw.Status(), "bytes", written, "elapsed",
fmt.Sprintf("%d.%03dms", millis, micros))
}()
s.clientHandler.ServeHTTP(sw, r)
}

// Shutdown the server gracefully.
func (s *Server) Shutdown() {
startedAt := time.Now()
slog.Info("starting shutdown", "protocol", s.Protocol(), "addresses", s.addresses, "port", s.port)
defer func() {
slog.Info("finished shutdown", "protocol", s.Protocol(), "addresses", s.addresses, "port", s.port, "elapsed",
time.Since(startedAt))
}()
logger := s.Logger.With("protocol", s.Protocol(), "addresses", s.addresses, "port", s.port)
logger.Info("starting shutdown")
defer func() { logger.Info("finished shutdown", "elapsed", time.Since(startedAt)) }()
gracePeriod := s.ShutdownGracePeriod
if gracePeriod <= 0 {
gracePeriod = time.Minute
}
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(gracePeriod))
defer cancel()
if err := s.WebServer.Shutdown(ctx); err != nil {
errs.LogTo(s.Logger, errs.NewWithCause("unable to shutdown gracefully", err), "protocol", s.Protocol())
errs.LogTo(logger, errs.NewWithCause("unable to shutdown gracefully", err))
}
if s.ShutdownCallback != nil {
s.ShutdownCallback()
s.ShutdownCallback(logger)
}
}

// MetadataFromRequest returns the Metadata from the request.
func MetadataFromRequest(req *http.Request) *Metadata {
if md, ok := req.Context().Value(metadataKey).(*Metadata); ok {
return md
}
return nil
}
22 changes: 0 additions & 22 deletions xio/network/xhttp/status.go

This file was deleted.

6 changes: 6 additions & 0 deletions xio/network/xhttp/status_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ func (w *StatusResponseWriter) Flush() {
f.Flush()
}
}

// ErrorStatus sends an HTTP response header with 'statusCode' and follows it with the standard text for that code as
// the body.
func ErrorStatus(w http.ResponseWriter, statusCode int) {
http.Error(w, http.StatusText(statusCode), statusCode)
}
78 changes: 0 additions & 78 deletions xio/network/xhttp/web/route.go

This file was deleted.

0 comments on commit d46053c

Please sign in to comment.