diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go index 34adf1d..55445cd 100644 --- a/httpserver/httpserver.go +++ b/httpserver/httpserver.go @@ -9,50 +9,80 @@ import ( "time" ) +type HandleChainFunc func(http.Handler) http.Handler + // A http server that has an inbuilt logger, name and complies wuth the Listener interface in // startup.Listeners. type Server struct { - http.Server - log Logger - name string + log Logger + name string + server http.Server + handlers []HandleChainFunc +} + +type ServerOption func(*Server) + +// WithHandler adds a handler on the http endpoint. +func WithHandler(h HandleChainFunc) ServerOption { + return func(s *Server) { + if h != nil { + s.handlers = append(s.handlers, h) + } + } +} + +// WithHandlers adds a handler on the http endpoint. +func WithHandlers(h []HandleChainFunc) ServerOption { + return func(s *Server) { + s.handlers = append(s.handlers, h...) + } } -func New(log Logger, name string, port string, handler http.Handler) *Server { - log.Debugf("New HTTPServer %s", name) - m := Server{ - Server: http.Server{ - Addr: ":" + port, - Handler: handler, +func New(log Logger, name string, port string, h http.Handler, opts ...ServerOption) *Server { + s := Server{ + server: http.Server{ + Addr: ":" + port, }, name: strings.ToLower(name), } - m.log = log.WithIndex("httpserver", m.String()) + s.log = log.WithIndex("httpserver", s.String()) + for _, opt := range opts { + opt(&s) + } + + s.log.Debugf("Initialise handlers %v", h) + for _, handler := range s.handlers { + if handler != nil { + h = handler(h) + } + } + s.server.Handler = h + // It is preferable to return a copy rather than a reference. Unfortunately http.Server has an // internal mutex and this cannot or should not be copied so we will return a reference instead. - log.Debugf("HTTPServer") - return &m + return &s } -func (m *Server) String() string { +func (s *Server) String() string { // No logging here please - return fmt.Sprintf("%s%s", m.name, m.Addr) + return fmt.Sprintf("%s%s", s.name, s.server.Addr) } -func (m *Server) Listen() error { - m.log.Infof("Listen") - err := m.Server.ListenAndServe() +func (s *Server) Listen() error { + s.log.Infof("Listen") + err := s.server.ListenAndServe() if err != nil { - return fmt.Errorf("%s server terminated: %v", m, err) + return fmt.Errorf("%s server terminated: %v", s, err) } return nil } -func (m *Server) Shutdown(ctx context.Context) error { +func (s *Server) Shutdown(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - m.log.Infof("Shutdown") - err := m.Server.Shutdown(ctx) + s.log.Infof("Shutdown") + err := s.server.Shutdown(ctx) if err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, context.Canceled) { return err } diff --git a/restproxyserver/restproxyserver.go b/restproxyserver/restproxyserver.go index 8bcecce..c838ad8 100644 --- a/restproxyserver/restproxyserver.go +++ b/restproxyserver/restproxyserver.go @@ -22,11 +22,13 @@ const ( type Marshaler = runtime.Marshaler type ServeMux = runtime.ServeMux type QueryParameterParser = runtime.QueryParameterParser +type HeaderMatcherFunc = runtime.HeaderMatcherFunc +type ErrorHandlerFunc = runtime.ErrorHandlerFunc type DialOption = grpc.DialOption type RegisterRESTProxyServer func(context.Context, *ServeMux, string, []DialOption) error -type HandleFunc func(http.Handler) http.Handler +type HandleChainFunc = httpserver.HandleChainFunc type filePath struct { verb string @@ -44,9 +46,8 @@ type RESTProxyServer struct { dialOptions []DialOption options []runtime.ServeMuxOption filePaths []filePath - handlers []HandleFunc - register RegisterRESTProxyServer - health RegisterRESTProxyServer + handlers []HandleChainFunc + registers []RegisterRESTProxyServer server *httpserver.Server } @@ -60,7 +61,7 @@ func WithMarshaler(mime string, m Marshaler) RESTProxyServerOption { } // SetQueryParameterParser adds an intercepror that matches header values. -func SetQueryParameterParser(p runtime.QueryParameterParser) RESTProxyServerOption { +func SetQueryParameterParser(p QueryParameterParser) RESTProxyServerOption { return func(g *RESTProxyServer) { g.options = append(g.options, runtime.SetQueryParameterParser(p)) } @@ -68,21 +69,21 @@ func SetQueryParameterParser(p runtime.QueryParameterParser) RESTProxyServerOpti // WithOutgoingHeaderMatcher matches header values on oupput. // WithIncomingHeaderMatcher adds an intercepror that matches header values. -func WithIncomingHeaderMatcher(o runtime.HeaderMatcherFunc) RESTProxyServerOption { +func WithIncomingHeaderMatcher(o HeaderMatcherFunc) RESTProxyServerOption { return func(g *RESTProxyServer) { g.options = append(g.options, runtime.WithIncomingHeaderMatcher(o)) } } // WithOutgoingHeaderMatcher matches header values on oupput. -func WithOutgoingHeaderMatcher(o runtime.HeaderMatcherFunc) RESTProxyServerOption { +func WithOutgoingHeaderMatcher(o HeaderMatcherFunc) RESTProxyServerOption { return func(g *RESTProxyServer) { g.options = append(g.options, runtime.WithOutgoingHeaderMatcher(o)) } } // WithErrorHandler adds error handling in special cases - e.g on 402 or 429. -func WithErrorHandler(o runtime.ErrorHandlerFunc) RESTProxyServerOption { +func WithErrorHandler(o ErrorHandlerFunc) RESTProxyServerOption { return func(g *RESTProxyServer) { g.options = append(g.options, runtime.WithErrorHandler(o)) } @@ -95,15 +96,15 @@ func WithGRPCAddress(a string) RESTProxyServerOption { } } -// WikthHealthHandler adds another grpc-gateway - typically grpcHealth. -func WithHealthHandler(r RegisterRESTProxyServer) RESTProxyServerOption { +// WikthRegisterHandler adds another grpc-gateway handler +func WithRegisterHandler(r RegisterRESTProxyServer) RESTProxyServerOption { return func(g *RESTProxyServer) { - g.health = r + g.registers = append(g.registers, r) } } -// WithHandler adds a handler on the http endpoint. -func WithHandler(h HandleFunc) RESTProxyServerOption { +// WithHTTPHandler adds a handler on the http endpoint. +func WithHTTPHandler(h HandleChainFunc) RESTProxyServerOption { return func(g *RESTProxyServer) { if h != nil { g.handlers = append(g.handlers, h) @@ -141,19 +142,17 @@ func WithHandlePath(verb string, urlPath string, f func(http.ResponseWriter, *ht // New creates a new RESTProxyServer that is bound to a specific GRPC Gateway API. This object complies with // the standard Listener interface and can be managed by the startup.Listeners object. -func New(log Logger, name string, port string, r RegisterRESTProxyServer, opts ...RESTProxyServerOption) RESTProxyServer { +func New(log Logger, name string, port string, opts ...RESTProxyServerOption) RESTProxyServer { var err error - log.Debugf("New RESTPROXY Server %s", name) - g := RESTProxyServer{ name: strings.ToLower(name), port: port, - register: r, dialOptions: tracing.GRPCDialTracingOptions(), options: []runtime.ServeMuxOption{}, filePaths: []filePath{}, - handlers: []HandleFunc{}, + handlers: []HandleChainFunc{}, + registers: []RegisterRESTProxyServer{}, } g.log = log.WithIndex("restproxyserver", g.String()) for _, opt := range opts { @@ -165,32 +164,27 @@ func New(log Logger, name string, port string, r RegisterRESTProxyServer, opts . g.grpcAddress = fmt.Sprintf("localhost:%s", port) } - log.Debugf("RESTPROXY Server") - mux := runtime.NewServeMux(g.options...) for _, p := range g.filePaths { err = mux.HandlePath(p.verb, p.urlPath, p.fileHandler) if err != nil { - log.Panicf("cannot handle path %s: %w", p.urlPath, err) + g.log.Panicf("cannot handle path %s: %w", p.urlPath, err) } } - err = g.register(context.Background(), mux, g.grpcAddress, g.dialOptions) - if err != nil { - log.Panicf("register error: %w", err) - } - if g.health != nil { - err = g.health(context.Background(), mux, g.grpcAddress, g.dialOptions) + for _, register := range g.registers { + err = register(context.Background(), mux, g.grpcAddress, g.dialOptions) if err != nil { - log.Panicf("healthregister error: %w", err) + g.log.Panicf("register error: %w", err) } } - - var h http.Handler = mux - for _, handler := range g.handlers { - h = handler(h) - } - g.server = httpserver.New(g.log, fmt.Sprintf("proxy %s", g.name), g.port, h) + g.server = httpserver.New( + g.log, + fmt.Sprintf("proxy %s", g.name), + g.port, + mux, + httpserver.WithHandlers(g.handlers), + ) return g }